tensorflow/third_party/implib_so/make_stub.py
Peter Hawkins 5b7a8b07e6 Generate CUDA stubs using implib.so, rather than by writing C++ stubs.
We load CUDA libraries lazily using dlopen()/dlsym() primarily to comply with the manylinux rules for Python wheels, which require that libraries in a wheel link directly only against an allowlisted set of libraries. To access CUDA through dlopen()/dlsym() without changing any of our CUDA-using code, TSL contains stub implementations of CUDA APIs that, when invoked, load the relevant library, obtain the requested symbol using dlsym(), and call into that symbol.
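The lazy-loading pattern can be sketched in Python with ctypes, whose `CDLL` calls dlopen() and whose attribute lookup calls dlsym() on POSIX systems. This is an illustration under assumptions, not TSL's code: `LazyLibrary`, `stub`, and `loaded` are hypothetical names, and the example resolves `strlen` from the running process (`CDLL(None)`) instead of a CUDA library, so it assumes a Linux or macOS host.

```python
import ctypes


class LazyLibrary:
  """Defers dlopen() until the first call and dlsym() until a symbol is used."""

  def __init__(self, path):
    self._path = path    # Passed to dlopen(); None means the running process.
    self._handle = None  # The ctypes.CDLL, created on first use.
    self._symbols = {}   # Symbol name -> resolved ctypes function pointer.

  @property
  def loaded(self):
    return self._handle is not None

  def _resolve(self, name):
    if self._handle is None:
      self._handle = ctypes.CDLL(self._path)  # Calls dlopen() on POSIX.
    if name not in self._symbols:
      self._symbols[name] = getattr(self._handle, name)  # Calls dlsym().
    return self._symbols[name]

  def stub(self, name, restype, argtypes):
    """Returns a callable that loads the library and symbol on first use."""

    def call(*args):
      fn = self._resolve(name)
      # ctypes must be told the C signature, much like a hand-written C++ stub.
      fn.restype = restype
      fn.argtypes = argtypes
      return fn(*args)

    return call


libc = LazyLibrary(None)
strlen = libc.stub('strlen', ctypes.c_size_t, [ctypes.c_char_p])
print(libc.loaded)           # False: no dlopen() yet.
print(strlen(b'manylinux'))  # 9
print(libc.loaded)           # True
```

Having to supply `restype` and `argtypes` for every function is the same per-function signature knowledge that makes C++ stubs verbose.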

The current CUDA stub libraries were generated by a Google-internal tool based on clang, which parses the CUDA headers and emits stub code for each API. It is therefore difficult to update the stubs without access to that tool. Each generated stub is a separate C++ function, so the stub code is very verbose. Further, the generated code requires a number of manual edits, which makes maintenance tedious.

However, there is a better way. implib.so (https://github.com/yugr/Implib.so) is a tool for automatically generating stubs from a .so file. Its approach is considerably simpler because it generates stubs in assembly language, where, it turns out, we do not need to know the type signature of the function being called: a completely generic trampoline suffices, with little or no bespoke knowledge of each function. We can adapt it to solve our CUDA stub problem.
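Python has no direct equivalent of an assembly trampoline, but the core idea, a stub that forwards a call without knowing its signature, can be modeled with `*args` forwarding. In this hypothetical sketch (`make_trampolines` is an illustrative name, and the `math` module stands in for a shared library), each symbol owns one slot in a table; the slot starts empty, a resolver fills it on the first call, and every later call goes straight through it.

```python
import math


def make_trampolines(names, resolve):
  """Builds one generic forwarding stub per name, backed by a shared table."""
  table = [None] * len(names)  # One slot per symbol, filled on first use.

  def make_trampoline(index):
    def trampoline(*args, **kwargs):
      if table[index] is None:
        table[index] = resolve(names[index])  # Lazy resolution.
      # Forward whatever arguments were passed; no signature is needed.
      return table[index](*args, **kwargs)

    return trampoline

  stubs = {name: make_trampoline(i) for i, name in enumerate(names)}
  return stubs, table


stubs, table = make_trampolines(['sqrt', 'hypot'], lambda n: getattr(math, n))
print(stubs['sqrt'](16.0))       # 4.0
print(table[1] is None)          # True: hypot has not been resolved yet.
print(stubs['hypot'](3.0, 4.0))  # 5.0
```

The same `trampoline` body serves both a one-argument and a two-argument function, which is the property the assembly version exploits.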

implib-gen.py, the generator script that implib.so provides, does not quite fit our needs because it requires access to the .so file at stub-generation time, which our Bazel build does not have. Instead, we split it into two phases:

get_symbols.py, which, given a .so file, extracts the list of public symbols that should be present in a stub. That list of symbols is checked into the TSL tree.
make_stub.py, which, given a list of symbols, generates a trampoline for each function.
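get_symbols.py itself is not part of this file. As one hypothetical way to produce a symbol list for the first phase, the sketch below filters text in the `ADDRESS TYPE NAME` format printed by GNU `nm -D --defined-only`, keeping global text (`T`) symbols and dropping any `@version` suffix. `parse_nm_output` is an illustrative helper and the sample input is made up; the resulting names, written one per line, match the format make_stub.py reads.

```python
def parse_nm_output(text):
  """Returns sorted names of defined, global function symbols.

  Expects lines in the `ADDRESS TYPE NAME` form printed by GNU nm.
  """
  names = set()
  for line in text.splitlines():
    fields = line.split()
    if len(fields) != 3:
      continue  # Skips undefined symbols (no address) and blank lines.
    _, sym_type, name = fields
    if sym_type != 'T':
      continue  # Keeps only global symbols in the text (code) section.
    names.add(name.split('@', 1)[0])  # Drops version suffixes.
  return sorted(names)


sample = """\
0000000000012340 T cuInit
0000000000012400 T cuDriverGetVersion@@libcuda.so.1
0000000000040000 D someGlobalData
                 U malloc
0000000000012500 t internalHelper
"""
print('\n'.join(parse_nm_output(sample)))
```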

This change switches the TSL CUDA build to use make_stub.py, generating stubs from the checked-in symbol lists at Bazel build time, which lets us delete over 130k lines of autogenerated C++ stub code.

PiperOrigin-RevId: 567635940
2023-09-22 09:09:08 -07:00


"""Given a list of symbols, generates a stub."""

import argparse
import configparser
import os
import string

from bazel_tools.tools.python.runfiles import runfiles

# Used to locate the implib.so data files in Bazel's runfiles.
r = runfiles.Create()


def main():
  parser = argparse.ArgumentParser(
      description='Generates stubs for CUDA libraries.'
  )
  parser.add_argument('symbols', help='File containing a list of symbols.')
  parser.add_argument(
      '--outdir', '-o', help='Path to create wrapper at', default='.'
  )
  parser.add_argument(
      '--target',
      help='Target platform name, e.g. x86_64, aarch64.',
      required=True,
  )
  args = parser.parse_args()

  # Per-architecture configuration and assembly templates from implib.so.
  config_path = r.Rlocation(f'implib_so/arch/{args.target}/config.ini')
  table_path = r.Rlocation(f'implib_so/arch/{args.target}/table.S.tpl')
  trampoline_path = r.Rlocation(
      f'implib_so/arch/{args.target}/trampoline.S.tpl'
  )

  cfg = configparser.ConfigParser(inline_comment_prefixes=';')
  cfg.read(config_path)
  ptr_size = int(cfg['Arch']['PointerSize'])

  # The symbols file lists one symbol name per line.
  with open(args.symbols, 'r') as f:
    funs = [s.strip() for s in f.readlines()]

  # Generate assembly code, containing a table for the resolved symbols and the
  # trampolines.
  lib_name, _ = os.path.splitext(os.path.basename(args.symbols))
  with open(os.path.join(args.outdir, f'{lib_name}.tramp.S'), 'w') as f:
    with open(table_path, 'r') as t:
      table_text = string.Template(t.read()).substitute(
          lib_suffix=lib_name, table_size=ptr_size * (len(funs) + 1)
      )
    f.write(table_text)

    with open(trampoline_path, 'r') as t:
      tramp_tpl = string.Template(t.read())

    # One trampoline per symbol; offset is the byte position of entry i in the
    # table of pointer-sized entries.
    for i, name in enumerate(funs):
      tramp_text = tramp_tpl.substitute(
          lib_suffix=lib_name, sym=name, offset=i * ptr_size, number=i
      )
      f.write(tramp_text)

  # Generates a list of symbols, formatted as a list of C++ strings.
  with open(os.path.join(args.outdir, f'{lib_name}.inc'), 'w') as f:
    sym_names = ''.join(f' "{name}",\n' for name in funs)
    f.write(sym_names)


if __name__ == '__main__':
  main()
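The layout arithmetic above can be checked without Bazel. The sketch below repeats make_stub.py's steps on in-memory inputs: it parses a config with the same `ConfigParser` settings, derives the library name from the symbols file name, sizes the table at `ptr_size * (len(funs) + 1)` bytes, gives symbol `i` the offset `i * ptr_size`, and builds the `.inc` text. The config text, the two template strings, the `symbols/cuda.symbols` path, and `render_stub` are simplified, hypothetical stand-ins, not the real contents of implib.so's `config.ini`, `table.S.tpl`, or `trampoline.S.tpl`.

```python
import configparser
import os
import string

# Hypothetical, simplified stand-ins for implib.so's per-architecture files.
CONFIG_INI = """\
[Arch]
PointerSize = 8 ; inline comment, stripped by inline_comment_prefixes
"""
TABLE_TPL = '# table _${lib_suffix}_tramp_table: ${table_size} bytes\n'
TRAMP_TPL = '# tramp ${number}: ${sym} -> table+${offset}\n'


def render_stub(symbols_path, funs):
  """Returns (lib_name, assembly text, .inc text) for a list of symbols."""
  cfg = configparser.ConfigParser(inline_comment_prefixes=';')
  cfg.read_string(CONFIG_INI)
  ptr_size = int(cfg['Arch']['PointerSize'])

  # Output files are named after the symbols file, minus its extension.
  lib_name, _ = os.path.splitext(os.path.basename(symbols_path))

  # Table: one pointer-sized entry per symbol, plus one extra entry.
  asm = string.Template(TABLE_TPL).substitute(
      lib_suffix=lib_name, table_size=ptr_size * (len(funs) + 1)
  )
  # offset is the byte position of entry i in that table.
  tramp_tpl = string.Template(TRAMP_TPL)
  for i, name in enumerate(funs):
    asm += tramp_tpl.substitute(
        lib_suffix=lib_name, sym=name, offset=i * ptr_size, number=i
    )

  inc = ''.join(f' "{name}",\n' for name in funs)
  return lib_name, asm, inc


lib_name, asm, inc = render_stub(
    'symbols/cuda.symbols', ['cuInit', 'cuDriverGetVersion', 'cuDeviceGet']
)
print(asm, end='')  # Table is 8 * (3 + 1) = 32 bytes; offsets 0, 8, 16.
print(inc, end='')
```

With `PointerSize = 4` every offset would halve, which is why the script reads the pointer size from the per-architecture config instead of hard-coding it.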