From f72b6cc7ca65ec9e017094ab908eb02a4a18cf65 Mon Sep 17 00:00:00 2001
From: Karol Herbst <kherbst@redhat.com>
Date: Wed, 17 Jan 2024 16:17:10 +0100
Subject: [PATCH] linker: allow linking functions with different pointer
 arguments

Since llvm-17 there are no typed pointers and hte SPIRV-LLVM-Translator
doesn't know the function signature of imported functions.

I'm investigating different ways of solving this problem and adding an
option to work around it inside `spirv-link` is one of those.

The code is almost complete, just I'm having troubles constructing the
bitcast to cast the pointer parameters to the final type.

Closes: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/2153
---
 include/spirv-tools/linker.hpp |   7 ++-
 source/link/linker.cpp         | 105 ++++++++++++++++++++++++++++++---
 tools/link/linker.cpp          |  24 +++++---
 3 files changed, 119 insertions(+), 17 deletions(-)

diff --git a/include/spirv-tools/linker.hpp b/include/spirv-tools/linker.hpp
index 6ba6e9654a..9037b94889 100644
--- a/include/spirv-tools/linker.hpp
+++ b/include/spirv-tools/linker.hpp
@@ -16,7 +16,6 @@
 #define INCLUDE_SPIRV_TOOLS_LINKER_HPP_
 
 #include <cstdint>
-
 #include <memory>
 #include <vector>
 
@@ -63,11 +62,17 @@ class SPIRV_TOOLS_EXPORT LinkerOptions {
     use_highest_version_ = use_highest_vers;
   }
 
+  bool GetAllowPtrTypeMismatch() const { return allow_ptr_type_mismatch_; }
+  void SetAllowPtrTypeMismatch(bool allow_ptr_type_mismatch) {
+    allow_ptr_type_mismatch_ = allow_ptr_type_mismatch;
+  }
+
  private:
   bool create_library_{false};
   bool verify_ids_{false};
   bool allow_partial_linkage_{false};
   bool use_highest_version_{false};
+  bool allow_ptr_type_mismatch_{false};
 };
 
 // Links one or more SPIR-V modules into a new SPIR-V module. That is, combine
diff --git a/source/link/linker.cpp b/source/link/linker.cpp
index 9e6520eaa8..e6aa72e32c 100644
--- a/source/link/linker.cpp
+++ b/source/link/linker.cpp
@@ -31,6 +31,7 @@
 #include "source/opt/build_module.h"
 #include "source/opt/compact_ids_pass.h"
 #include "source/opt/decoration_manager.h"
+#include "source/opt/ir_builder.h"
 #include "source/opt/ir_loader.h"
 #include "source/opt/pass_manager.h"
 #include "source/opt/remove_duplicates_pass.h"
@@ -46,12 +47,14 @@ namespace spvtools {
 namespace {
 
 using opt::Instruction;
+using opt::InstructionBuilder;
 using opt::IRContext;
 using opt::Module;
 using opt::PassManager;
 using opt::RemoveDuplicatesPass;
 using opt::analysis::DecorationManager;
 using opt::analysis::DefUseManager;
+using opt::analysis::Function;
 using opt::analysis::Type;
 using opt::analysis::TypeManager;
 
@@ -126,6 +129,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
 // checked.
 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
                                             const LinkageTable& linkings_to_do,
+                                            bool allow_ptr_type_mismatch,
                                             opt::IRContext* context);
 
 // Remove linkage specific instructions, such as prototypes of imported
@@ -502,6 +506,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
 
 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
                                             const LinkageTable& linkings_to_do,
+                                            bool allow_ptr_type_mismatch,
                                             opt::IRContext* context) {
   spv_position_t position = {};
 
@@ -513,7 +518,34 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
         type_manager.GetType(linking_entry.imported_symbol.type_id);
     Type* exported_symbol_type =
         type_manager.GetType(linking_entry.exported_symbol.type_id);
-    if (!(*imported_symbol_type == *exported_symbol_type))
+    if (!(*imported_symbol_type == *exported_symbol_type)) {
+      Function* imported_symbol_type_func = imported_symbol_type->AsFunction();
+      Function* exported_symbol_type_func = exported_symbol_type->AsFunction();
+
+      if (imported_symbol_type_func && exported_symbol_type_func) {
+        const auto& imported_params = imported_symbol_type_func->param_types();
+        const auto& exported_params = exported_symbol_type_func->param_types();
+        // allow_ptr_type_mismatch allows linking functions where the pointer
+        // type of arguments doesn't match. Everything else still needs to be
+        // equal. This is to workaround LLVM-17+ not having typed pointers and
+        // generated SPIR-Vs not knowing the actual pointer types in some cases.
+        if (allow_ptr_type_mismatch &&
+            imported_params.size() == exported_params.size()) {
+          bool correct = true;
+          for (size_t i = 0; i < imported_params.size(); i++) {
+            const auto& imported_param = imported_params[i];
+            const auto& exported_param = exported_params[i];
+
+            if (!imported_param->IsSame(exported_param) &&
+                (imported_param->kind() != Type::kPointer ||
+                 exported_param->kind() != Type::kPointer)) {
+              correct = false;
+              break;
+            }
+          }
+          if (correct) continue;
+        }
+      }
       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
              << "Type mismatch on symbol \""
              << linking_entry.imported_symbol.name
@@ -521,6 +553,7 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
              << linking_entry.imported_symbol.id
              << " and exported variable/function %"
              << linking_entry.exported_symbol.id << ".";
+    }
   }
 
   // Ensure the import and export decorations are similar
@@ -696,6 +729,57 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
   return SPV_SUCCESS;
 }
 
+spv_result_t FixFunctionCallTypes(opt::IRContext& context,
+                                  const LinkageTable& linkings) {
+  auto mod = context.module();
+  const auto type_manager = context.get_type_mgr();
+  const auto def_use_mgr = context.get_def_use_mgr();
+
+  for (auto& func : *mod) {
+    func.ForEachInst([&](Instruction* inst) {
+      if (inst->opcode() != spv::Op::OpFunctionCall) return;
+      opt::Operand& target = inst->GetInOperand(0);
+
+      // only fix calls to imported functions
+      auto linking = std::find_if(
+          linkings.begin(), linkings.end(), [&](const auto& entry) {
+            return entry.exported_symbol.id == target.AsId();
+          });
+      if (linking == linkings.end()) return;
+
+      auto builder = InstructionBuilder(&context, inst);
+      for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
+        auto exported_func_param =
+            def_use_mgr->GetDef(linking->exported_symbol.parameter_ids[i - 1]);
+        const Type* target_type =
+            type_manager->GetType(exported_func_param->type_id());
+        if (target_type->kind() != Type::kPointer) continue;
+
+        opt::Operand& arg = inst->GetInOperand(i);
+        const Type* param_type =
+            type_manager->GetType(def_use_mgr->GetDef(arg.AsId())->type_id());
+
+        // No need to cast if it already matches
+        if (*param_type == *target_type) continue;
+
+        auto new_id = context.TakeNextId();
+
+        // cast to the expected pointer type
+        builder.AddInstruction(MakeUnique<opt::Instruction>(
+            &context, spv::Op::OpBitcast, exported_func_param->type_id(),
+            new_id,
+            opt::Instruction::OperandList(
+                {{SPV_OPERAND_TYPE_ID, {arg.AsId()}}})));
+
+        inst->SetInOperand(i, {new_id});
+      }
+    });
+  }
+  context.InvalidateAnalyses(opt::IRContext::kAnalysisDefUse |
+                             opt::IRContext::kAnalysisInstrToBlockMapping);
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t Link(const Context& context,
@@ -789,8 +873,9 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
   if (res != SPV_SUCCESS) return res;
 
   // Phase 6: Ensure the import and export have the same types and decorations.
-  res =
-      CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context);
+  res = CheckImportExportCompatibility(consumer, linkings_to_do,
+                                       options.GetAllowPtrTypeMismatch(),
+                                       &linked_context);
   if (res != SPV_SUCCESS) return res;
 
   // Phase 7: Remove all names and decorations of import variables/functions
@@ -815,21 +900,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
                                           &linked_context);
   if (res != SPV_SUCCESS) return res;
 
-  // Phase 10: Compact the IDs used in the module
+  // Phase 10: Optionally fix function call types
+  if (options.GetAllowPtrTypeMismatch()) {
+    res = FixFunctionCallTypes(linked_context, linkings_to_do);
+    if (res != SPV_SUCCESS) return res;
+  }
+
+  // Phase 11: Compact the IDs used in the module
   manager.AddPass<opt::CompactIdsPass>();
   pass_res = manager.Run(&linked_context);
   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
 
-  // Phase 11: Recompute EntryPoint variables
+  // Phase 12: Recompute EntryPoint variables
   manager.AddPass<opt::RemoveUnusedInterfaceVariablesPass>();
   pass_res = manager.Run(&linked_context);
   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
 
-  // Phase 12: Warn if SPIR-V limits were exceeded
+  // Phase 13: Warn if SPIR-V limits were exceeded
   res = VerifyLimits(consumer, linked_context);
   if (res != SPV_SUCCESS) return res;
 
-  // Phase 13: Output the module
+  // Phase 14: Output the module
   linked_context.module()->ToBinary(linked_binary, true);
 
   return SPV_SUCCESS;
diff --git a/tools/link/linker.cpp b/tools/link/linker.cpp
index f3898aab0d..7dbd8596fa 100644
--- a/tools/link/linker.cpp
+++ b/tools/link/linker.cpp
@@ -48,6 +48,10 @@ Options (in lexicographical order):
   --allow-partial-linkage
                Allow partial linkage by accepting imported symbols to be
                unresolved.
+  --allow-pointer-mismatch
+               Allow pointer function parameters to mismatch the target link
+               target. This is useful to workaround lost correct parameter type
+               information due to LLVM's opaque pointers.
   --create-library
                Link the binaries into a library, keeping all exported symbols.
   -h, --help
@@ -77,15 +81,16 @@ Options (in lexicographical order):
 }  // namespace
 
 // clang-format off
-FLAG_SHORT_bool(  h,                     /* default_value= */ false,               /* required= */ false);
-FLAG_LONG_bool(   help,                  /* default_value= */ false,               /* required= */false);
-FLAG_LONG_bool(   version,               /* default_value= */ false,               /* required= */ false);
-FLAG_LONG_bool(   verify_ids,            /* default_value= */ false,               /* required= */ false);
-FLAG_LONG_bool(   create_library,        /* default_value= */ false,               /* required= */ false);
-FLAG_LONG_bool(   allow_partial_linkage, /* default_value= */ false,               /* required= */ false);
-FLAG_SHORT_string(o,                     /* default_value= */ "",                  /* required= */ false);
-FLAG_LONG_string( target_env,            /* default_value= */ kDefaultEnvironment, /* required= */ false);
-FLAG_LONG_bool(   use_highest_version,   /* default_value= */ false,               /* required= */ false);
+FLAG_SHORT_bool(  h,                      /* default_value= */ false,               /* required= */ false);
+FLAG_LONG_bool(   help,                   /* default_value= */ false,               /* required= */ false);
+FLAG_LONG_bool(   version,                /* default_value= */ false,               /* required= */ false);
+FLAG_LONG_bool(   verify_ids,             /* default_value= */ false,               /* required= */ false);
+FLAG_LONG_bool(   create_library,         /* default_value= */ false,               /* required= */ false);
+FLAG_LONG_bool(   allow_partial_linkage,  /* default_value= */ false,               /* required= */ false);
+FLAG_LONG_bool(   allow_pointer_mismatch, /* default_value= */ false,               /* required= */ false);
+FLAG_SHORT_string(o,                      /* default_value= */ "",                  /* required= */ false);
+FLAG_LONG_string( target_env,             /* default_value= */ kDefaultEnvironment, /* required= */ false);
+FLAG_LONG_bool(   use_highest_version,    /* default_value= */ false,               /* required= */ false);
 // clang-format on
 
 int main(int, const char* argv[]) {
@@ -126,6 +131,7 @@ int main(int, const char* argv[]) {
 
   spvtools::LinkerOptions options;
   options.SetAllowPartialLinkage(flags::allow_partial_linkage.value());
+  options.SetAllowPtrTypeMismatch(flags::allow_pointer_mismatch.value());
   options.SetCreateLibrary(flags::create_library.value());
   options.SetVerifyIds(flags::verify_ids.value());
   options.SetUseHighestVersion(flags::use_highest_version.value());
