mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 21:05:19 +00:00
8958c652df
Imported from GitHub PR https://github.com/openxla/xla/pull/7849 MPI collectives as proposed in https://github.com/google/jax/issues/11182?notification_referrer_id=NT_kwDOAG8zGbIzODQ5MDcxMzM0OjcyODc1Nzc#issuecomment-1851591135. I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly. For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`. @hawkinsp Copybara import of the project: -- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>: add mpi collectives -- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>: add explicit Init and Finalize methods and export them to python -- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>: add comment -- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>: fix windows build -- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>: fmt -- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>: bump xla_extension_version Merging this change closes #7849 PiperOrigin-RevId: 620302264
150 lines
6.0 KiB
Diff
150 lines
6.0 KiB
Diff
diff --git a/gen/gen_decl.py b/gen/gen_decl.py
|
|
index 1005b95..696b4e0 100755
|
|
--- a/gen/gen_decl.py
|
|
+++ b/gen/gen_decl.py
|
|
@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
|
|
|
|
from mpi_constants import constants
|
|
from mpi_functions import functions
|
|
-from mpi_constants_fortran import constants_fortran
|
|
-from mpi_functions_fortran import functions_fortran
|
|
+# from mpi_constants_fortran import constants_fortran
|
|
+# from mpi_functions_fortran import functions_fortran
|
|
|
|
support_profiling = True
|
|
have_weak_symbols = False
|
|
@@ -24,7 +24,7 @@ def wrap(line):
|
|
lines.append(line)
|
|
return "\n".join(lines)
|
|
|
|
-with open("include/mpi_decl_constants_c.h", "w") as file:
|
|
+with open(sys.argv[1], "w") as file:
|
|
file.write("// Declare C MPI constants\n")
|
|
file.write("\n")
|
|
for (tp, nm) in constants:
|
|
@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file:
|
|
'mpi_nm': nm}
|
|
file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs))
|
|
|
|
-with open("include/mpi_decl_functions_c.h", "w") as file:
|
|
+with open(sys.argv[2], "w") as file:
|
|
file.write("// Declare C MPI functions\n")
|
|
file.write("\n")
|
|
for (tp, nm, args, flags) in functions:
|
|
@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file:
|
|
file.write(Template("\n".join(tmpl)).substitute(subs))
|
|
file.write("\n")
|
|
|
|
-with open("include/mpi_decl_constants_fortran.h", "w") as file:
|
|
+if False:
|
|
file.write("! Declare Fortran MPI constants\n")
|
|
file.write("\n")
|
|
for (tp, nm) in constants_fortran:
|
|
@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file:
|
|
file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl)))
|
|
file.write("\n")
|
|
|
|
-with open("include/mpi_decl_functions_fortran.h", "w") as file:
|
|
+if False:
|
|
file.write("! Declare Fortran MPI functions\n")
|
|
file.write("\n")
|
|
for (tp, nm, args) in functions_fortran:
|
|
diff --git a/gen/gen_defn.py b/gen/gen_defn.py
|
|
index bf31f35..318222e 100755
|
|
--- a/gen/gen_defn.py
|
|
+++ b/gen/gen_defn.py
|
|
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
|
|
|
|
from mpi_constants import constants
|
|
from mpi_functions import functions
|
|
-from mpi_constants_fortran import constants_fortran
|
|
-from mpi_functions_fortran import functions_fortran
|
|
+# from mpi_constants_fortran import constants_fortran
|
|
+# from mpi_functions_fortran import functions_fortran
|
|
|
|
support_profiling = True
|
|
have_weak_symbols = False
|
|
replace_sentinels = False
|
|
|
|
-with open("src/mpi_defn_constants_c.h", "w") as file:
|
|
+with open(sys.argv[1], "w") as file:
|
|
file.write("// Define C MPI constants")
|
|
file.write("\n")
|
|
for (tp, nm) in constants:
|
|
@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file:
|
|
'mpi_nm': nm}
|
|
file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs))
|
|
|
|
-with open("src/mpi_defn_functions_c.h", "w") as file:
|
|
+with open(sys.argv[2], "w") as file:
|
|
file.write("// Define C MPI functions\n")
|
|
file.write("\n")
|
|
for (tp, nm, args, flags) in functions:
|
|
@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file:
|
|
file.write(Template("\n".join(tmpl)).substitute(subs))
|
|
file.write("\n")
|
|
|
|
-with open("src/mpi_defn_constants_fortran.h", "w") as file:
|
|
+if False:
|
|
file.write("// Define Fortran MPI constants\n")
|
|
file.write("\n")
|
|
for (tp, nm) in constants_fortran:
|
|
@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file:
|
|
# Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes
|
|
file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs))
|
|
|
|
-with open("src/mpi_defn_functions_fortran.h", "w") as file:
|
|
+if False:
|
|
file.write("// Define Fortran MPI functions\n")
|
|
file.write("\n")
|
|
for (tp, nm, args) in functions_fortran:
|
|
diff --git a/gen/gen_init.py b/gen/gen_init.py
|
|
index 4939261..0e52822 100755
|
|
--- a/gen/gen_init.py
|
|
+++ b/gen/gen_init.py
|
|
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))
|
|
|
|
from mpi_constants import constants
|
|
from mpi_functions import functions
|
|
-from mpi_constants_fortran import constants_fortran
|
|
-from mpi_functions_fortran import functions_fortran
|
|
+# from mpi_constants_fortran import constants_fortran
|
|
+# from mpi_functions_fortran import functions_fortran
|
|
|
|
support_profiling = True
|
|
have_weak_symbols = False
|
|
replace_sentinels = False
|
|
|
|
-with open("src/mpi_init_constants_c.h", "w") as file:
|
|
+with open(sys.argv[1], "w") as file:
|
|
file.write("// Initialize C MPI constants")
|
|
file.write("\n")
|
|
for (tp, nm) in constants:
|
|
@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file:
|
|
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)}
|
|
file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
|
|
|
|
-with open("src/mpi_init_functions_c.h", "w") as file:
|
|
+with open(sys.argv[2], "w") as file:
|
|
file.write("// Initialize C MPI functions\n")
|
|
file.write("\n")
|
|
for (tp, nm, args, flags) in functions:
|
|
@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file:
|
|
subs['anm{0}'.format(i)] = anm
|
|
file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
|
|
|
|
-with open("src/mpi_init_constants_fortran.h", "w") as file:
|
|
+if False:
|
|
file.write("// Initialize Fortran MPI constants\n")
|
|
file.write("\n")
|
|
for (tp, nm) in constants_fortran:
|
|
@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file:
|
|
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"}
|
|
file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))
|
|
|
|
-with open("src/mpi_init_functions_fortran.h", "w") as file:
|
|
+if False:
|
|
file.write("// Initialize Fortran MPI functions\n")
|
|
file.write("\n")
|
|
for (tp, nm, args) in functions_fortran:
|