From b2959739ed9f2f305269b91b1aa5167d4a1a9688 Mon Sep 17 00:00:00 2001
From: Zebediah Figura <zfigura@codeweavers.com>
Date: Wed, 5 Apr 2023 14:09:16 -0500
Subject: [PATCH] vkd3d-shader/hlsl: Rewrite the register allocator to allow
 allocating in multiple passes.

We will need this in order to allocate some "special" registers: ps_1_* output, sincos output, etc.
---
 libs/vkd3d-shader/hlsl_codegen.c | 129 +++++++++++++++++--------------
 1 file changed, 71 insertions(+), 58 deletions(-)

diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c
index a2f3242a..ea343edc 100644
--- a/libs/vkd3d-shader/hlsl_codegen.c
+++ b/libs/vkd3d-shader/hlsl_codegen.c
@@ -2745,43 +2745,61 @@ static void compute_liveness(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl
 
 struct register_allocator
 {
-    size_t size;
-    uint32_t reg_count;
-    struct
+    size_t count, capacity;
+
+    /* Highest register index that has been allocated.
+     * Used to declare sm4 temp count. */
+    uint32_t max_reg;
+
+    struct allocation
     {
-        /* 0 if not live yet. */
-        unsigned int last_read;
-    } *regs;
+        uint32_t reg;
+        unsigned int writemask;
+        unsigned int first_write, last_read;
+    } *allocations;
 };
 
-static unsigned int get_available_writemask(struct register_allocator *allocator,
-        unsigned int first_write, unsigned int component_idx, unsigned int reg_size)
+static unsigned int get_available_writemask(const struct register_allocator *allocator,
+        unsigned int first_write, unsigned int last_read, uint32_t reg_idx)
 {
-    unsigned int i, writemask = 0, count = 0;
+    unsigned int writemask = VKD3DSP_WRITEMASK_ALL;
+    size_t i;
 
-    for (i = 0; i < 4; ++i)
+    for (i = 0; i < allocator->count; ++i)
     {
-        if (allocator->regs[component_idx + i].last_read <= first_write)
-        {
-            writemask |= 1u << i;
-            if (++count == reg_size)
-                return writemask;
-        }
+        const struct allocation *allocation = &allocator->allocations[i];
+
+        /* We do not overlap if first write == last read:
+         * this is the case where we are allocating the result of that
+         * expression, e.g. "add r0, r0, r1". */
+
+        if (allocation->reg == reg_idx
+                && first_write < allocation->last_read && last_read > allocation->first_write)
+            writemask &= ~allocation->writemask;
+
+        if (!writemask)
+            break;
     }
 
-    return 0;
+    return writemask;
 }
 
-static bool resize_liveness(struct hlsl_ctx *ctx, struct register_allocator *allocator, size_t new_count)
+static void record_allocation(struct hlsl_ctx *ctx, struct register_allocator *allocator,
+        uint32_t reg_idx, unsigned int writemask, unsigned int first_write, unsigned int last_read)
 {
-    size_t old_capacity = allocator->size;
+    struct allocation *allocation;
 
-    if (!hlsl_array_reserve(ctx, (void **)&allocator->regs, &allocator->size, new_count, sizeof(*allocator->regs)))
-        return false;
+    if (!hlsl_array_reserve(ctx, (void **)&allocator->allocations, &allocator->capacity,
+            allocator->count + 1, sizeof(*allocator->allocations)))
+        return;
 
-    if (allocator->size > old_capacity)
-        memset(allocator->regs + old_capacity, 0, (allocator->size - old_capacity) * sizeof(*allocator->regs));
-    return true;
+    allocation = &allocator->allocations[allocator->count++];
+    allocation->reg = reg_idx;
+    allocation->writemask = writemask;
+    allocation->first_write = first_write;
+    allocation->last_read = last_read;
+
+    allocator->max_reg = max(allocator->max_reg, reg_idx);
 }
 
 /* reg_size is the number of register components to be reserved, while component_count is the number
@@ -2791,42 +2809,39 @@ static struct hlsl_reg allocate_register(struct hlsl_ctx *ctx, struct register_a
         unsigned int first_write, unsigned int last_read, unsigned int reg_size,
         unsigned int component_count)
 {
-    unsigned int component_idx, writemask, i;
     struct hlsl_reg ret = {0};
+    unsigned int writemask;
+    uint32_t reg_idx;
 
     assert(component_count <= reg_size);
 
-    for (component_idx = 0; component_idx < allocator->size; component_idx += 4)
+    for (reg_idx = 0;; ++reg_idx)
     {
-        if ((writemask = get_available_writemask(allocator, first_write, component_idx, reg_size)))
+        writemask = get_available_writemask(allocator, first_write, last_read, reg_idx);
+
+        if (vkd3d_popcount(writemask) >= reg_size)
+        {
+            writemask = hlsl_combine_writemasks(writemask, (1u << reg_size) - 1);
             break;
+        }
     }
-    if (component_idx == allocator->size)
-    {
-        if (!resize_liveness(ctx, allocator, component_idx + 4))
-            return ret;
-        writemask = (1u << reg_size) - 1;
-    }
-    for (i = 0; i < 4; ++i)
-    {
-        if (writemask & (1u << i))
-            allocator->regs[component_idx + i].last_read = last_read;
-    }
-    ret.id = component_idx / 4;
+
+    record_allocation(ctx, allocator, reg_idx, writemask, first_write, last_read);
+
+    ret.id = reg_idx;
     ret.writemask = hlsl_combine_writemasks(writemask, (1u << component_count) - 1);
     ret.allocated = true;
-    allocator->reg_count = max(allocator->reg_count, ret.id + 1);
     return ret;
 }
 
-static bool is_range_available(struct register_allocator *allocator, unsigned int first_write,
-        unsigned int component_idx, unsigned int reg_size)
+static bool is_range_available(const struct register_allocator *allocator,
+        unsigned int first_write, unsigned int last_read, uint32_t reg_idx, unsigned int reg_size)
 {
-    unsigned int i;
+    uint32_t i;
 
-    for (i = 0; i < reg_size; i += 4)
+    for (i = 0; i < align(reg_size, 4) / 4; ++i) /* Round up: a partial last register must be free too. */
     {
-        if (!get_available_writemask(allocator, first_write, component_idx + i, 4))
+        if (get_available_writemask(allocator, first_write, last_read, reg_idx + i) != VKD3DSP_WRITEMASK_ALL)
             return false;
     }
     return true;
@@ -2835,23 +2850,21 @@ static bool is_range_available(struct register_allocator *allocator, unsigned in
 static struct hlsl_reg allocate_range(struct hlsl_ctx *ctx, struct register_allocator *allocator,
         unsigned int first_write, unsigned int last_read, unsigned int reg_size)
 {
-    unsigned int i, component_idx;
     struct hlsl_reg ret = {0};
+    uint32_t reg_idx;
+    unsigned int i;
 
-    for (component_idx = 0; component_idx < allocator->size; component_idx += 4)
+    for (reg_idx = 0;; ++reg_idx)
     {
-        if (is_range_available(allocator, first_write, component_idx,
-                min(reg_size, allocator->size - component_idx)))
+        if (is_range_available(allocator, first_write, last_read, reg_idx, reg_size))
             break;
     }
-    if (!resize_liveness(ctx, allocator, component_idx + reg_size))
-        return ret;
 
-    for (i = 0; i < reg_size; ++i)
-        allocator->regs[component_idx + i].last_read = last_read;
-    ret.id = component_idx / 4;
+    for (i = 0; i < align(reg_size, 4) / 4; ++i) /* Round up: also reserve a partial last register. */
+        record_allocation(ctx, allocator, reg_idx + i, VKD3DSP_WRITEMASK_ALL, first_write, last_read);
+
+    ret.id = reg_idx;
     ret.allocated = true;
-    allocator->reg_count = max(allocator->reg_count, ret.id + align(reg_size, 4));
     return ret;
 }
 
@@ -3075,7 +3088,7 @@ static void allocate_const_registers(struct hlsl_ctx *ctx, struct hlsl_ir_functi
         }
     }
 
-    vkd3d_free(allocator.regs);
+    vkd3d_free(allocator.allocations);
 }
 
 /* Simple greedy temporary register allocation pass that just assigns a unique
@@ -3086,8 +3099,8 @@ static void allocate_temp_registers(struct hlsl_ctx *ctx, struct hlsl_ir_functio
 {
     struct register_allocator allocator = {0};
     allocate_temp_registers_recurse(ctx, &entry_func->body, &allocator);
-    ctx->temp_count = allocator.reg_count;
-    vkd3d_free(allocator.regs);
+    ctx->temp_count = allocator.max_reg + 1;
+    vkd3d_free(allocator.allocations);
 }
 
 static void allocate_semantic_register(struct hlsl_ctx *ctx, struct hlsl_ir_var *var, unsigned int *counter, bool output)