mirror of
https://github.com/encounter/decomp.me.git
synced 2026-03-30 11:06:27 -07:00
571 lines
18 KiB
Python
571 lines
18 KiB
Python
import base64
|
|
import hashlib
|
|
import io
|
|
import json
|
|
import logging
|
|
import re
|
|
import zipfile
|
|
from datetime import datetime
|
|
from typing import Any, Dict, Optional
|
|
|
|
import django_filters
|
|
from coreapp import compilers, platforms
|
|
from django.core.files import File
|
|
from django.db.models import F, FloatField, When, Case, Value
|
|
from django.db.models.functions import Cast
|
|
from django.db.models.query import QuerySet
|
|
|
|
from django.http import HttpResponse, QueryDict
|
|
from django.utils.decorators import method_decorator
|
|
|
|
from rest_framework import filters, mixins, serializers, status
|
|
from rest_framework.decorators import action
|
|
from rest_framework.exceptions import APIException
|
|
from rest_framework.pagination import CursorPagination
|
|
from rest_framework.response import Response
|
|
from rest_framework.viewsets import GenericViewSet
|
|
|
|
from ..compiler_wrapper import CompilationResult, CompilerWrapper, DiffResult
|
|
from ..decompiler_wrapper import DecompilerWrapper
|
|
from ..decorators.cache import globally_cacheable
|
|
from ..decorators.django import condition
|
|
from ..diff_wrapper import DiffWrapper
|
|
from ..error import CompilationError, DiffError
|
|
from ..filters.search import NonEmptySearchFilter
|
|
from ..flags import Language
|
|
from ..libraries import Library
|
|
from ..middleware import Request
|
|
from ..models.preset import Preset
|
|
from ..models.scratch import Asm, Assembly, Scratch
|
|
from ..platforms import Platform
|
|
from ..serializers import (
|
|
ClaimableScratchSerializer,
|
|
ScratchCreateSerializer,
|
|
ScratchSerializer,
|
|
TerseScratchSerializer,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ProjectNotMemberException(APIException):
|
|
status_code = status.HTTP_403_FORBIDDEN
|
|
default_detail = "You must be a maintainer of the project to perform this action."
|
|
|
|
|
|
def get_db_asm(request_asm: str) -> Asm:
|
|
h = hashlib.sha256(request_asm.encode()).hexdigest()
|
|
asm, _ = Asm.objects.get_or_create(
|
|
hash=h,
|
|
defaults={
|
|
"data": request_asm,
|
|
},
|
|
)
|
|
return asm
|
|
|
|
|
|
# 1 MB
|
|
MAX_FILE_SIZE = 1000 * 1024
|
|
|
|
|
|
def cache_object(platform: Platform, file: File[Any]) -> Assembly:
|
|
# Validate file size
|
|
if file.size > MAX_FILE_SIZE:
|
|
raise serializers.ValidationError(
|
|
f"Object must be less than {MAX_FILE_SIZE} bytes"
|
|
)
|
|
|
|
# Check if ELF, Mach-O, or PE
|
|
obj_bytes = file.read()
|
|
is_elf = obj_bytes[:4] == b"\x7fELF"
|
|
is_macho = obj_bytes[:4] == b"\xcf\xfa\xed\xfe"
|
|
is_coff = obj_bytes[:2] in (b"\x4c\x01", b"\x64\x86")
|
|
if not (is_elf or is_macho or is_coff):
|
|
raise serializers.ValidationError("Object must be an ELF, Mach-O, or COFF file")
|
|
|
|
assembly, _ = Assembly.objects.get_or_create(
|
|
hash=hashlib.sha256(obj_bytes).hexdigest(),
|
|
defaults={
|
|
"arch": platform.arch,
|
|
"elf_object": obj_bytes,
|
|
},
|
|
)
|
|
return assembly
|
|
|
|
|
|
def compile_scratch(scratch: Scratch) -> CompilationResult:
|
|
try:
|
|
return CompilerWrapper.compile_code(
|
|
compilers.from_id(scratch.compiler),
|
|
scratch.compiler_flags,
|
|
scratch.source_code,
|
|
scratch.context,
|
|
scratch.diff_label,
|
|
tuple(scratch.libraries),
|
|
)
|
|
except (CompilationError, APIException) as e:
|
|
return CompilationResult(b"", str(e))
|
|
|
|
|
|
def diff_compilation(scratch: Scratch, compilation: CompilationResult) -> DiffResult:
|
|
try:
|
|
return DiffWrapper.diff(
|
|
scratch.target_assembly,
|
|
platforms.from_id(scratch.platform),
|
|
scratch.diff_label,
|
|
bytes(compilation.elf_object),
|
|
diff_flags=scratch.diff_flags,
|
|
)
|
|
except DiffError as e:
|
|
return DiffResult(None, str(e))
|
|
|
|
|
|
def update_scratch_score(scratch: Scratch, diff: DiffResult) -> None:
|
|
"""
|
|
Given a scratch and a diff, update the scratch's score
|
|
"""
|
|
|
|
if diff.result is None:
|
|
return
|
|
score = diff.result.get("current_score", scratch.score)
|
|
max_score = diff.result.get("max_score", scratch.max_score)
|
|
if score != scratch.score or max_score != scratch.max_score:
|
|
scratch.score = score
|
|
scratch.max_score = max_score
|
|
scratch.save(update_fields=["score", "max_score"])
|
|
|
|
|
|
def compile_scratch_update_score(scratch: Scratch) -> None:
|
|
"""
|
|
Initialize the scratch's score and ignore errors should they occur
|
|
"""
|
|
|
|
compilation = compile_scratch(scratch)
|
|
try:
|
|
diff = diff_compilation(scratch, compilation)
|
|
update_scratch_score(scratch, diff)
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def scratch_last_modified(
|
|
request: Request, pk: Optional[str] = None
|
|
) -> Optional[datetime]:
|
|
scratch: Optional[Scratch] = Scratch.objects.filter(slug=pk).first()
|
|
if scratch:
|
|
return scratch.last_updated
|
|
else:
|
|
return None
|
|
|
|
|
|
scratch_condition = condition(last_modified_func=scratch_last_modified)
|
|
|
|
|
|
def is_contentful_asm(asm: Optional[Asm]) -> bool:
|
|
if asm is None:
|
|
return False
|
|
|
|
asm_text = asm.data.strip()
|
|
|
|
if asm_text == "" or asm_text == "nop":
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def update_needs_recompile(partial: Dict[str, Any]) -> bool:
|
|
recompile_params = [
|
|
"compiler",
|
|
"compiler_flags",
|
|
"diff_flags",
|
|
"diff_label",
|
|
"source_code",
|
|
"context",
|
|
]
|
|
|
|
for param in recompile_params:
|
|
if param in partial:
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def create_scratch(data: Dict[str, Any], allow_project: bool = False) -> Scratch:
|
|
create_ser = ScratchCreateSerializer(data=data)
|
|
create_ser.is_valid(raise_exception=True)
|
|
data = create_ser.validated_data
|
|
|
|
platform: Optional[Platform] = data.get("platform")
|
|
compiler = compilers.from_id(data["compiler"])
|
|
|
|
if not platform:
|
|
platform = compiler.platform
|
|
|
|
target_asm: str = data.get("target_asm", "")
|
|
target_obj: File[Any] | None = data.get("target_obj")
|
|
context: str = data["context"]
|
|
diff_label: str = data.get("diff_label", "")
|
|
|
|
if target_obj:
|
|
asm = None
|
|
assembly = cache_object(platform, target_obj)
|
|
else:
|
|
asm = get_db_asm(target_asm)
|
|
assembly = CompilerWrapper.assemble_asm(platform, asm)
|
|
|
|
source_code = data.get("source_code")
|
|
if asm and not source_code:
|
|
default_source_code = f"void {diff_label or 'func'}(void) {{\n // ...\n}}\n"
|
|
source_code = DecompilerWrapper.decompile(
|
|
default_source_code, platform, asm.data, context, compiler
|
|
)
|
|
|
|
compiler_flags = data.get("compiler_flags", "")
|
|
compiler_flags = CompilerWrapper.filter_compiler_flags(compiler_flags)
|
|
|
|
diff_flags = data.get("diff_flags", [])
|
|
|
|
preset_id: Optional[str] = None
|
|
if data.get("preset"):
|
|
preset: Preset = data["preset"]
|
|
preset_id = str(preset.id)
|
|
|
|
name = data.get("name", diff_label) or "Untitled"
|
|
|
|
libraries = [
|
|
Library(**lib) if isinstance(lib, dict) else lib for lib in data["libraries"]
|
|
]
|
|
|
|
ser = ScratchSerializer(
|
|
data={
|
|
"name": name,
|
|
"compiler": compiler.id,
|
|
"compiler_flags": compiler_flags,
|
|
"diff_flags": diff_flags,
|
|
"preset": preset_id,
|
|
"context": context,
|
|
"diff_label": diff_label,
|
|
"source_code": source_code,
|
|
}
|
|
)
|
|
ser.is_valid(raise_exception=True)
|
|
scratch = ser.save(
|
|
target_assembly=assembly,
|
|
platform=platform.id,
|
|
libraries=libraries,
|
|
)
|
|
|
|
compile_scratch_update_score(scratch)
|
|
|
|
return scratch
|
|
|
|
|
|
class ScratchPagination(CursorPagination):
|
|
ordering = "-creation_time"
|
|
page_size = 10
|
|
page_size_query_param = "page_size"
|
|
max_page_size = 100
|
|
|
|
|
|
@method_decorator(globally_cacheable(max_age=5, stale_while_revalidate=1), name="list")
|
|
class ScratchViewSet(
|
|
mixins.CreateModelMixin,
|
|
mixins.DestroyModelMixin,
|
|
mixins.RetrieveModelMixin,
|
|
mixins.UpdateModelMixin,
|
|
mixins.ListModelMixin,
|
|
GenericViewSet, # type: ignore
|
|
):
|
|
match_percent = Case(
|
|
When(max_score__lte=0, then=Value(0.0)),
|
|
When(score__lt=0, then=Value(0.0)),
|
|
When(score__gt=F("max_score"), then=Value(0.0)),
|
|
When(score=0, then=Value(1.0)),
|
|
When(match_override=True, then=Value(1.0)),
|
|
default=1.0 - (F("score") / Cast("max_score", FloatField())),
|
|
)
|
|
|
|
queryset = (
|
|
Scratch.objects.all()
|
|
.select_related("owner__user__github")
|
|
.annotate(match_percent=match_percent)
|
|
)
|
|
pagination_class = ScratchPagination
|
|
filterset_fields = ["platform", "compiler", "preset"]
|
|
filter_backends = [
|
|
django_filters.rest_framework.DjangoFilterBackend,
|
|
NonEmptySearchFilter,
|
|
filters.OrderingFilter,
|
|
]
|
|
search_fields = ["name", "diff_label"]
|
|
ordering_fields = ["creation_time", "last_updated", "score", "match_percent"]
|
|
|
|
def get_serializer_class(self) -> type[serializers.ModelSerializer[Scratch]]:
|
|
if self.action == "list":
|
|
return TerseScratchSerializer
|
|
else:
|
|
return ScratchSerializer
|
|
|
|
@scratch_condition
|
|
def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Response:
|
|
return super().retrieve(request, *args, **kwargs)
|
|
|
|
def create(self, request: Any, *args: Any, **kwargs: Any) -> Response:
|
|
scratch = create_scratch(request.data)
|
|
|
|
return Response(
|
|
ClaimableScratchSerializer(scratch, context={"request": request}).data,
|
|
status=status.HTTP_201_CREATED,
|
|
)
|
|
|
|
# TODO: possibly move this logic into ScratchSerializer.save method
|
|
def update(self, request: Any, *args: Any, **kwargs: Any) -> Response:
|
|
# Check permission
|
|
scratch = self.get_object()
|
|
if scratch.owner != request.profile:
|
|
response = self.retrieve(request, *args, **kwargs)
|
|
response.status_code = status.HTTP_403_FORBIDDEN
|
|
return response
|
|
|
|
response = super().update(request, *args, **kwargs)
|
|
|
|
if update_needs_recompile(request.data):
|
|
scratch = self.get_object()
|
|
compile_scratch_update_score(scratch)
|
|
return Response(
|
|
ScratchSerializer(scratch, context={"request": request}).data
|
|
)
|
|
|
|
return response
|
|
|
|
def destroy(self, request: Any, *args: Any, **kwargs: Any) -> Response:
|
|
# Check permission
|
|
scratch = self.get_object()
|
|
if scratch.owner != request.profile and not request.profile.is_staff():
|
|
response = self.retrieve(request, *args, **kwargs)
|
|
response.status_code = status.HTTP_403_FORBIDDEN
|
|
return response
|
|
|
|
response = super().destroy(request, *args, **kwargs)
|
|
|
|
return response
|
|
|
|
# POST on compile takes a partial and does not update the scratch's compilation status
|
|
@action(detail=True, methods=["GET", "POST"])
|
|
def compile(self, request: Request, pk: str) -> Response:
|
|
scratch: Scratch = self.get_object()
|
|
|
|
# Apply partial
|
|
include_objects = False
|
|
if request.method == "POST":
|
|
# TODO: use a serializer w/ validation
|
|
if "compiler" in request.data:
|
|
scratch.compiler = request.data["compiler"]
|
|
if "compiler_flags" in request.data:
|
|
scratch.compiler_flags = request.data["compiler_flags"]
|
|
if "diff_flags" in request.data:
|
|
scratch.diff_flags = request.data["diff_flags"]
|
|
if "diff_label" in request.data:
|
|
scratch.diff_label = request.data["diff_label"]
|
|
if "source_code" in request.data:
|
|
scratch.source_code = request.data["source_code"]
|
|
if "context" in request.data:
|
|
scratch.context = request.data["context"]
|
|
if "libraries" in request.data:
|
|
libs = [Library(**lib) for lib in request.data["libraries"]]
|
|
scratch.libraries = libs
|
|
if "include_objects" in request.data:
|
|
include_objects = request.data["include_objects"]
|
|
|
|
compilation = compile_scratch(scratch)
|
|
diff = diff_compilation(scratch, compilation)
|
|
|
|
if request.method == "GET":
|
|
update_scratch_score(scratch, diff)
|
|
|
|
compiler_output = ""
|
|
if compilation.errors:
|
|
compiler_output += compilation.errors + "\n"
|
|
if diff.errors:
|
|
compiler_output += diff.errors + "\n"
|
|
|
|
response = {
|
|
"diff_output": diff.result,
|
|
"compiler_output": compiler_output,
|
|
"success": compilation.elf_object is not None
|
|
and len(compilation.elf_object) > 0,
|
|
}
|
|
|
|
if include_objects or request.method == "GET":
|
|
|
|
def to_base64(obj: bytes) -> str:
|
|
return base64.b64encode(obj).decode("utf-8")
|
|
|
|
response["left_object"] = to_base64(scratch.target_assembly.elf_object)
|
|
response["right_object"] = to_base64(compilation.elf_object)
|
|
|
|
return Response(response)
|
|
|
|
@action(detail=True, methods=["POST"])
|
|
def decompile(self, request: Request, pk: str) -> Response:
|
|
scratch: Scratch = self.get_object()
|
|
if scratch.target_assembly.source_asm is None:
|
|
return Response(
|
|
{
|
|
"decompilation": "This scratch cannot currently be run through the decompiler because it was created via object file."
|
|
}
|
|
)
|
|
|
|
context = request.data.get("context", scratch.context)
|
|
compiler = compilers.from_id(request.data.get("compiler", scratch.compiler))
|
|
|
|
platform = platforms.from_id(scratch.platform)
|
|
|
|
decompilation = DecompilerWrapper.decompile(
|
|
"",
|
|
platform,
|
|
scratch.target_assembly.source_asm.data,
|
|
context,
|
|
compiler,
|
|
)
|
|
|
|
return Response({"decompilation": decompilation})
|
|
|
|
@action(detail=True, methods=["POST"])
|
|
def claim(self, request: Request, pk: str) -> Response:
|
|
scratch: Scratch = self.get_object()
|
|
token = request.data.get("token")
|
|
|
|
if not scratch.is_claimable():
|
|
return Response({"success": False})
|
|
|
|
if scratch.claim_token and scratch.claim_token != token:
|
|
return Response({"success": False})
|
|
|
|
profile = request.profile
|
|
|
|
logger.debug(f"Granting ownership of scratch {scratch} to {profile}")
|
|
|
|
scratch.owner = profile
|
|
scratch.claim_token = None
|
|
scratch.save()
|
|
|
|
return Response({"success": True})
|
|
|
|
@action(detail=True, methods=["POST"])
|
|
def fork(self, request: Request, pk: str) -> Response:
|
|
parent: Scratch = self.get_object()
|
|
|
|
# TODO Needed for test_fork_scratch test?
|
|
if isinstance(request.data, QueryDict):
|
|
request_data = request.data.dict()
|
|
else:
|
|
request_data = request.data
|
|
|
|
parent_data = ScratchSerializer(parent, context={"request": request}).data
|
|
fork_data = {**parent_data, **request_data}
|
|
|
|
ser = ScratchSerializer(data=fork_data, context={"request": request})
|
|
ser.is_valid(raise_exception=True)
|
|
|
|
libraries = [Library(**lib) for lib in ser.validated_data["libraries"]]
|
|
new_scratch = ser.save(
|
|
parent=parent,
|
|
target_assembly=parent.target_assembly,
|
|
platform=parent.platform,
|
|
libraries=libraries,
|
|
)
|
|
|
|
compile_scratch_update_score(new_scratch)
|
|
|
|
return Response(
|
|
ClaimableScratchSerializer(new_scratch, context={"request": request}).data,
|
|
status=status.HTTP_201_CREATED,
|
|
)
|
|
|
|
@action(detail=True)
|
|
@scratch_condition
|
|
def export(self, request: Request, pk: str) -> HttpResponse:
|
|
scratch: Scratch = self.get_object()
|
|
|
|
metadata = ScratchSerializer(scratch, context={"request": request}).data
|
|
metadata.pop("source_code")
|
|
metadata.pop("context")
|
|
|
|
zip_bytes = io.BytesIO()
|
|
with zipfile.ZipFile(
|
|
zip_bytes, mode="w", compression=zipfile.ZIP_DEFLATED
|
|
) as zip_f:
|
|
zip_f.writestr("metadata.json", json.dumps(metadata, indent=4))
|
|
if scratch.target_assembly.source_asm is not None:
|
|
zip_f.writestr("target.s", scratch.target_assembly.source_asm.data)
|
|
zip_f.writestr("target.o", scratch.target_assembly.elf_object)
|
|
|
|
language = compilers.from_id(scratch.compiler).language
|
|
src_ext = Language(language).get_file_extension()
|
|
zip_f.writestr(f"code.{src_ext}", scratch.source_code)
|
|
if scratch.context:
|
|
zip_f.writestr(f"ctx.{src_ext}", scratch.context)
|
|
|
|
if request.GET.get("target_only") != "1":
|
|
compilation = compile_scratch(scratch)
|
|
if compilation.elf_object:
|
|
zip_f.writestr("current.o", compilation.elf_object)
|
|
|
|
# Prevent possible header injection attacks
|
|
safe_name = re.sub(r"[^a-zA-Z0-9_:]", "_", scratch.name)[:64]
|
|
|
|
return HttpResponse(
|
|
zip_bytes.getvalue(),
|
|
headers={
|
|
"Content-Type": "application/zip",
|
|
"Content-Disposition": f"attachment; filename={safe_name}.zip",
|
|
},
|
|
)
|
|
|
|
@action(detail=True)
|
|
def family(self, request: Request, pk: str) -> Response:
|
|
scratch: Scratch = self.get_object()
|
|
|
|
subqueries: list[QuerySet["Scratch"]] = []
|
|
|
|
if is_contentful_asm(scratch.target_assembly.source_asm):
|
|
assert scratch.target_assembly.source_asm is not None
|
|
subqueries.append(
|
|
Scratch.objects.filter(
|
|
target_assembly__source_asm__hash=scratch.target_assembly.source_asm.hash
|
|
)
|
|
)
|
|
elif (
|
|
scratch.target_assembly.elf_object is not None
|
|
and len(scratch.target_assembly.elf_object) > 0
|
|
):
|
|
subqueries.append(
|
|
Scratch.objects.filter(
|
|
target_assembly__hash=scratch.target_assembly.hash,
|
|
diff_label=scratch.diff_label,
|
|
)
|
|
)
|
|
else:
|
|
subqueries.append(Scratch.objects.filter(slug=scratch.slug))
|
|
|
|
if scratch.family_id is not None:
|
|
subqueries.append(Scratch.objects.filter(family_id=scratch.family_id))
|
|
|
|
if scratch.parent_id is not None:
|
|
subqueries.append(Scratch.objects.filter(parent_id=scratch.parent_id))
|
|
|
|
# Avoid 'ORDER BY not allowed in subqueries of compound statements.'
|
|
subqueries = [sq.order_by() for sq in subqueries]
|
|
|
|
if len(subqueries) == 1:
|
|
family = subqueries[0]
|
|
else:
|
|
family = subqueries[0].union(*subqueries[1:])
|
|
|
|
family = family.order_by("creation_time")
|
|
|
|
return Response(
|
|
TerseScratchSerializer(family, many=True, context={"request": request}).data
|
|
)
|