mirror of
https://github.com/tensorflow/tensorflow.git
synced 2024-11-21 21:05:19 +00:00
8958c652df
Imported from GitHub PR https://github.com/openxla/xla/pull/7849

MPI collectives as proposed in https://github.com/google/jax/issues/11182?notification_referrer_id=NT_kwDOAG8zGbIzODQ5MDcxMzM0OjcyODc1Nzc#issuecomment-1851591135.

I have only implemented the inter-process communication; this does not yet support more than one thread per process. Adding proper support for multiple threads/devices per process in the future looks considerably more involved.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library, one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp

Copybara import of the project:

-- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>: add mpi collectives
-- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>: add explicit Init and Finalize methods and export them to python
-- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>: add comment
-- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>: fix windows build
-- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>: fmt
-- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>: bump xla_extension_version

Merging this change closes #7849

PiperOrigin-RevId: 620302264
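The message above says MPItrampoline locates the wrapped MPI library through the `MPITRAMPOLINE_LIB` environment variable. The following is a minimal Python sketch of preparing that environment for a child process. It assumes the job is started from a Python driver script. `mpitrampoline_env` is a hypothetical helper, and `/path/to/libmpiwrapper.so` is the placeholder path from the message.

```python
import os
from pathlib import Path


def mpitrampoline_env(lib_path, base_env=None, check_exists=True):
    """Return a copy of an environment with MPITRAMPOLINE_LIB set to lib_path.

    Hypothetical helper: per the commit message, MPItrampoline reads this
    variable at runtime to find the MPIwrapper library it should dlopen.
    """
    # Copy so the caller's environment (or os.environ) is never mutated.
    env = dict(os.environ if base_env is None else base_env)
    path = Path(lib_path)
    # Fail early with a clear error instead of a dlopen failure later on.
    if check_exists and not path.is_file():
        raise FileNotFoundError(f"MPIwrapper library not found: {path}")
    env["MPITRAMPOLINE_LIB"] = str(path)
    return env
```

The returned dictionary can be passed as `env=` to `subprocess.run` together with whatever launch command your MPI installation provides. Copying the environment rather than editing `os.environ` keeps the setting scoped to the launched job.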
"""Provides the repository macro to import mpitrampoline."""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Imports mpitrampoline."""

    # Pinned upstream commit and the SHA-256 checksum of its source tarball.
    MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed"
    MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84"

    tf_http_archive(
        name = "mpitrampoline",
        sha256 = MPITRAMPOLINE_SHA256,
        # Top-level directory inside the GitHub tarball, stripped on extraction.
        strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT),
        urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)),
        # Patch applied to the upstream sources, and the BUILD file used for them.
        patch_file = ["//third_party/mpitrampoline:gen.patch"],
        build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD",
    )
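The macro pins MPItrampoline by commit and checksum. The tarball URL and `strip_prefix` are both formatted from `MPITRAMPOLINE_COMMIT`, and `sha256` is the expected checksum of the downloaded archive. The following Python sketch reproduces that string templating and the checksum comparison to show what the pinned values mean. It is an illustration only, not Bazel's or `tf_http_archive`'s implementation, and `archive_url`, `strip_prefix`, and `verify_sha256` are hypothetical helper names.

```python
import hashlib

COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed"


def archive_url(commit):
    """GitHub source-tarball URL for a commit (same template as the macro)."""
    return "https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit=commit)


def strip_prefix(commit):
    """Top-level directory inside the tarball (same template as the macro)."""
    return "MPItrampoline-{commit}".format(commit=commit)


def verify_sha256(data, expected_hex):
    """Return the hex digest of data, raising if it differs from expected_hex."""
    actual = hashlib.sha256(data).hexdigest()
    if actual != expected_hex.lower():
        raise ValueError(f"checksum mismatch: expected {expected_hex}, got {actual}")
    return actual
```

Note that the URL uses the lowercase `mpitrampoline` path while `strip_prefix` uses `MPItrampoline-`, matching the directory name inside the archive. To check a tarball fetched by hand, read it in binary mode and pass its bytes to `verify_sha256` together with `MPITRAMPOLINE_SHA256`.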