tensorflow/third_party/mpitrampoline/gen.patch
Clemens Giuliani 8958c652df PR #7849: [XLA:CPU] Add support for cross-process collectives using mpi.
Imported from GitHub PR https://github.com/openxla/xla/pull/7849

MPI collectives, as proposed in https://github.com/google/jax/issues/11182?notification_referrer_id=NT_kwDOAG8zGbIzODQ5MDcxMzM0OjcyODc1Nzc#issuecomment-1851591135.

I only implemented inter-process communication; this does not yet support more than one thread per process. Adding proper support for multiple threads/devices per process in the future looks considerably more involved.

For MPI, I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library, one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp
Copybara import of the project:

--
b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>:

add mpi collectives

--
23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>:

add explicit Init and Finalize methods and export them to python

--
bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>:

add comment

--
38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>:

fix windows build

--
201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>:

fmt

--
2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>:

bump xla_extension_version

Merging this change closes #7849

PiperOrigin-RevId: 620302264
2024-03-29 13:09:13 -07:00

150 lines
6.0 KiB
Diff

diff --git a/gen/gen_decl.py b/gen/gen_decl.py
index 1005b95..696b4e0 100755
--- a/gen/gen_decl.py
+++ b/gen/gen_decl.py
@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran
support_profiling = True
have_weak_symbols = False
@@ -24,7 +24,7 @@ def wrap(line):
lines.append(line)
return "\n".join(lines)
-with open("include/mpi_decl_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Declare C MPI constants\n")
file.write("\n")
for (tp, nm) in constants:
@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file:
'mpi_nm': nm}
file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs))
-with open("include/mpi_decl_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Declare C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file:
file.write(Template("\n".join(tmpl)).substitute(subs))
file.write("\n")
-with open("include/mpi_decl_constants_fortran.h", "w") as file:
+if False:
file.write("! Declare Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file:
file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl)))
file.write("\n")
-with open("include/mpi_decl_functions_fortran.h", "w") as file:
+if False:
file.write("! Declare Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
diff --git a/gen/gen_defn.py b/gen/gen_defn.py
index bf31f35..318222e 100755
--- a/gen/gen_defn.py
+++ b/gen/gen_defn.py
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran
support_profiling = True
have_weak_symbols = False
replace_sentinels = False
-with open("src/mpi_defn_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Define C MPI constants")
file.write("\n")
for (tp, nm) in constants:
@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file:
'mpi_nm': nm}
file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs))
-with open("src/mpi_defn_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Define C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file:
file.write(Template("\n".join(tmpl)).substitute(subs))
file.write("\n")
-with open("src/mpi_defn_constants_fortran.h", "w") as file:
+if False:
file.write("// Define Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file:
# Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes
file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs))
-with open("src/mpi_defn_functions_fortran.h", "w") as file:
+if False:
file.write("// Define Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
diff --git a/gen/gen_init.py b/gen/gen_init.py
index 4939261..0e52822 100755
--- a/gen/gen_init.py
+++ b/gen/gen_init.py
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran
support_profiling = True
have_weak_symbols = False
replace_sentinels = False
-with open("src/mpi_init_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Initialize C MPI constants")
file.write("\n")
for (tp, nm) in constants:
@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file:
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)}
file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
-with open("src/mpi_init_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Initialize C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file:
subs['anm{0}'.format(i)] = anm
file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
-with open("src/mpi_init_constants_fortran.h", "w") as file:
+if False:
file.write("// Initialize Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file:
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"}
file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
-with open("src/mpi_init_functions_fortran.h", "w") as file:
+if False:
file.write("// Initialize Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran: