mirror of
https://github.com/encounter/decomp.me.git
synced 2026-03-30 11:06:27 -07:00
fd246922e9
* Bump backend deps * Add timeout to sentry.io publisher * Appease mypy
356 lines
12 KiB
Python
356 lines
12 KiB
Python
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
|
|
from django.contrib.auth.models import User
|
|
from html_json_forms.serializers import JSONFormSerializer
|
|
from rest_framework import serializers
|
|
from rest_framework.exceptions import APIException
|
|
from rest_framework.relations import PKOnlyObject, SlugRelatedField
|
|
|
|
from coreapp import platforms
|
|
|
|
from . import compilers
|
|
from .flags import LanguageFlagSet
|
|
from .libraries import Library
|
|
from .models.github import GitHubUser
|
|
from .models.preset import Preset
|
|
from .models.profile import Profile
|
|
from .models.project import Project, ProjectMember
|
|
from .models.scratch import Scratch
|
|
|
|
|
|
def serialize_profile(profile: Profile) -> Dict[str, Any]:
|
|
if profile.user is None:
|
|
return {
|
|
"is_anonymous": True,
|
|
"id": profile.id,
|
|
"is_online": profile.is_online(),
|
|
"is_admin": False,
|
|
"username": f"{profile.pseudonym} (anon)",
|
|
"frog_color": profile.get_frog_color(),
|
|
}
|
|
else:
|
|
user = profile.user
|
|
|
|
gh_user: Optional[GitHubUser] = getattr(user, "github", None)
|
|
if not gh_user:
|
|
# NOTE: All models with an "owner" should fetch related "owner__user__github"
|
|
# in order to avoid N+1 queries when a Profile is serialized for each object.
|
|
gh_user = GitHubUser.objects.filter(user=user).first()
|
|
|
|
return {
|
|
"is_anonymous": False,
|
|
"id": profile.id,
|
|
"is_online": profile.is_online(),
|
|
"is_admin": user.is_staff,
|
|
"username": user.username,
|
|
"github_id": gh_user.github_id if gh_user else None,
|
|
}
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
ProfileFieldBaseClass = serializers.RelatedField[Profile, str, Dict[str, Any]]
|
|
else:
|
|
ProfileFieldBaseClass = serializers.RelatedField
|
|
|
|
|
|
class ProfileField(ProfileFieldBaseClass):
|
|
def to_representation(self, value: Profile | PKOnlyObject) -> dict[str, Any]:
|
|
if isinstance(value, Profile):
|
|
return serialize_profile(value)
|
|
# fallback
|
|
return super().to_representation(value)
|
|
|
|
|
|
class LibrarySerializer(serializers.Serializer[Library]):
|
|
name = serializers.CharField()
|
|
version = serializers.CharField()
|
|
|
|
|
|
class TinyPresetSerializer(serializers.ModelSerializer[Preset]):
|
|
class Meta:
|
|
model = Preset
|
|
fields = ["id", "name"]
|
|
|
|
|
|
class TersePresetSerializer(serializers.ModelSerializer[Preset]):
|
|
libraries = serializers.ListField(child=LibrarySerializer(), default=list)
|
|
owner = ProfileField(read_only=True)
|
|
|
|
class Meta:
|
|
model = Preset
|
|
fields = [
|
|
"id",
|
|
"name",
|
|
"owner",
|
|
"platform",
|
|
"compiler",
|
|
"compiler_flags",
|
|
"diff_flags",
|
|
"libraries",
|
|
]
|
|
|
|
|
|
class PresetSerializer(serializers.ModelSerializer[Preset]):
|
|
libraries = serializers.ListField(child=LibrarySerializer(), default=list)
|
|
num_scratches = serializers.SerializerMethodField()
|
|
owner = ProfileField(read_only=True)
|
|
|
|
class Meta:
|
|
model = Preset
|
|
fields = [
|
|
"id",
|
|
"name",
|
|
"platform",
|
|
"compiler",
|
|
"assembler_flags",
|
|
"compiler_flags",
|
|
"diff_flags",
|
|
"decompiler_flags",
|
|
"libraries",
|
|
"num_scratches",
|
|
"owner",
|
|
]
|
|
read_only_fields = [
|
|
"creation_time",
|
|
"last_updated",
|
|
"owner",
|
|
]
|
|
|
|
def get_num_scratches(self, preset: Preset) -> int:
|
|
return Scratch.objects.filter(preset=preset).count()
|
|
|
|
def validate_platform(self, platform: str) -> str:
|
|
try:
|
|
platforms.from_id(platform)
|
|
except Exception:
|
|
raise serializers.ValidationError(f"Unknown platform: {platform}")
|
|
return platform
|
|
|
|
def validate_compiler(self, compiler: str) -> str:
|
|
try:
|
|
compilers.from_id(compiler)
|
|
except Exception:
|
|
raise serializers.ValidationError(f"Unknown compiler: {compiler}")
|
|
return compiler
|
|
|
|
def validate(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
compiler = compilers.from_id(data["compiler"])
|
|
platform = platforms.from_id(data["platform"])
|
|
|
|
if compiler.platform != platform:
|
|
raise serializers.ValidationError(
|
|
f"Compiler {compiler.id} is not compatible with platform {platform.id}"
|
|
)
|
|
return data
|
|
|
|
|
|
class ScratchCreateSerializer(serializers.Serializer[None]):
|
|
name = serializers.CharField(allow_blank=True, required=False)
|
|
compiler = serializers.CharField(allow_blank=True, required=False)
|
|
platform = serializers.CharField(allow_blank=True, required=False)
|
|
compiler_flags = serializers.CharField(allow_blank=True, required=False)
|
|
diff_flags = serializers.JSONField(required=False)
|
|
preset = serializers.PrimaryKeyRelatedField(
|
|
required=False, queryset=Preset.objects.all()
|
|
)
|
|
source_code = serializers.CharField(allow_blank=True, required=False)
|
|
target_asm = serializers.CharField(allow_blank=True, required=False)
|
|
target_obj = serializers.FileField(allow_null=True, required=False)
|
|
context = serializers.CharField(allow_blank=True) # type: ignore
|
|
diff_label = serializers.CharField(allow_blank=True, required=False)
|
|
libraries = serializers.JSONField(default=list) # type: ignore
|
|
|
|
project = serializers.CharField(allow_blank=False, required=False)
|
|
rom_address = serializers.IntegerField(required=False)
|
|
|
|
def validate_platform(self, platform: str) -> str:
|
|
try:
|
|
platforms.from_id(platform)
|
|
except Exception:
|
|
raise serializers.ValidationError(f"Unknown platform: {platform}")
|
|
return platform
|
|
|
|
def validate_compiler(self, compiler: str) -> str:
|
|
try:
|
|
compilers.from_id(compiler)
|
|
except Exception:
|
|
raise serializers.ValidationError(f"Unknown compiler: {compiler}")
|
|
return compiler
|
|
|
|
def validate_libraries(
|
|
self, libraries: list[dict[str, str]]
|
|
) -> list[dict[str, str]]:
|
|
for library in libraries:
|
|
for key in ["name", "version"]:
|
|
if key not in library:
|
|
raise serializers.ValidationError(
|
|
f"Library {library} is missing '{key}' key"
|
|
)
|
|
return libraries
|
|
|
|
def validate(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
if "preset" in data:
|
|
preset: Preset = data["preset"]
|
|
# Preset dictates platform
|
|
data["platform"] = platforms.from_id(preset.platform)
|
|
|
|
if "compiler" not in data or not data["compiler"]:
|
|
data["compiler"] = preset.compiler
|
|
|
|
if "compiler_flags" not in data or not data["compiler_flags"]:
|
|
data["compiler_flags"] = preset.compiler_flags
|
|
|
|
if "diff_flags" not in data or not data["diff_flags"]:
|
|
data["diff_flags"] = preset.diff_flags
|
|
|
|
if "libraries" not in data or not data["libraries"]:
|
|
data["libraries"] = preset.libraries
|
|
else:
|
|
if "compiler" not in data or not data["compiler"]:
|
|
raise serializers.ValidationError(
|
|
"Compiler must be provided when preset is omitted"
|
|
)
|
|
|
|
try:
|
|
compiler = compilers.from_id(data["compiler"])
|
|
except APIException:
|
|
raise serializers.ValidationError(
|
|
f"Unknown compiler: {data['compiler']}"
|
|
)
|
|
|
|
if "platform" not in data or not data["platform"]:
|
|
data["platform"] = compiler.platform
|
|
else:
|
|
try:
|
|
platform = platforms.from_id(data["platform"])
|
|
except APIException:
|
|
raise serializers.ValidationError(
|
|
f"Unknown platform: {data['platform']}"
|
|
)
|
|
|
|
if compiler.platform != platform:
|
|
raise serializers.ValidationError(
|
|
f"Compiler {compiler.id} is not compatible with platform {platform.id}"
|
|
)
|
|
data["platform"] = platform
|
|
return data
|
|
|
|
|
|
class ScratchSerializer(serializers.ModelSerializer[Scratch]):
|
|
slug = serializers.SlugField(read_only=True)
|
|
parent = serializers.PrimaryKeyRelatedField(read_only=True) # type: ignore
|
|
owner = ProfileField(read_only=True)
|
|
source_code = serializers.CharField(allow_blank=True, trim_whitespace=False)
|
|
context = serializers.CharField(allow_blank=True, trim_whitespace=False) # type: ignore
|
|
language = serializers.SerializerMethodField()
|
|
libraries = serializers.ListField(child=LibrarySerializer(), default=list)
|
|
preset = serializers.PrimaryKeyRelatedField(
|
|
required=False, allow_null=True, queryset=Preset.objects.all()
|
|
)
|
|
|
|
class Meta:
|
|
model = Scratch
|
|
exclude = [
|
|
"claim_token",
|
|
"target_assembly",
|
|
]
|
|
read_only_fields = [
|
|
"parent",
|
|
"owner",
|
|
"last_updated",
|
|
"creation_time",
|
|
"platform",
|
|
]
|
|
|
|
def get_language(self, scratch: Scratch) -> str:
|
|
"""
|
|
Strategy for extracting a scratch's language:
|
|
- If the scratch's compiler has a LanguageFlagSet in its flags, attempt to match a language flag against that
|
|
- Otherwise, fallback to the compiler's default language
|
|
"""
|
|
compiler = compilers.from_id(scratch.compiler)
|
|
language_flag_set = next(
|
|
(i for i in compiler.flags if isinstance(i, LanguageFlagSet)),
|
|
None,
|
|
)
|
|
|
|
if language_flag_set:
|
|
matches = [
|
|
(flag, language)
|
|
for flag, language in language_flag_set.flags.items()
|
|
if flag in scratch.compiler_flags
|
|
]
|
|
|
|
if matches:
|
|
# taking the longest avoids detecting C++ as C
|
|
longest_match = max(matches, key=lambda m: len(m[0]))
|
|
return longest_match[1].value
|
|
|
|
# If we're here, either the compiler doesn't have a LanguageFlagSet, or the scratch doesn't
|
|
# have a flag within it.
|
|
# Either way: fall back to the compiler default.
|
|
return compiler.language.value
|
|
|
|
|
|
class TerseScratchSerializer(ScratchSerializer):
|
|
owner = ProfileField(read_only=True)
|
|
|
|
class Meta:
|
|
model = Scratch
|
|
fields = [
|
|
"slug",
|
|
"owner",
|
|
"last_updated",
|
|
"creation_time",
|
|
"platform",
|
|
"compiler",
|
|
"preset",
|
|
"name",
|
|
"score",
|
|
"max_score",
|
|
"match_override",
|
|
"parent",
|
|
"preset",
|
|
"libraries",
|
|
]
|
|
|
|
|
|
# On initial creation, include the "claim_token" field.
|
|
class ClaimableScratchSerializer(ScratchSerializer):
|
|
claim_token = serializers.CharField(read_only=True)
|
|
|
|
class Meta(ScratchSerializer.Meta):
|
|
exclude = [
|
|
field for field in ScratchSerializer.Meta.exclude if field != "claim_token"
|
|
]
|
|
|
|
|
|
class ProjectSerializer(JSONFormSerializer, serializers.ModelSerializer[Project]):
|
|
slug = serializers.SlugField()
|
|
|
|
class Meta:
|
|
model = Project
|
|
exclude: List[str] = []
|
|
|
|
def create(self, validated_data: Any) -> Project:
|
|
project = Project.objects.create(**validated_data)
|
|
return project
|
|
|
|
def update(self, instance: Project, validated_data: Any) -> Project:
|
|
for attr, value in validated_data.items():
|
|
setattr(instance, attr, value)
|
|
instance.save()
|
|
return instance
|
|
|
|
|
|
class ProjectMemberSerializer(serializers.ModelSerializer[ProjectMember]):
|
|
username = SlugRelatedField(
|
|
source="user",
|
|
slug_field="username",
|
|
queryset=User.objects.all(),
|
|
)
|
|
|
|
class Meta:
|
|
model = ProjectMember
|
|
fields = ["username"]
|