tensorflow/third_party/mpitrampoline/workspace.bzl
Clemens Giuliani 8958c652df PR #7849: [XLA:CPU] Add support for cross-process collectives using mpi.
Imported from GitHub PR https://github.com/openxla/xla/pull/7849

MPI collectives as proposed in https://github.com/google/jax/issues/11182?notification_referrer_id=NT_kwDOAG8zGbIzODQ5MDcxMzM0OjcyODc1Nzc#issuecomment-1851591135.

I only implemented inter-process communication; this does not yet support more than one thread per process. Properly supporting multiple threads/devices per process in the future seems considerably more involved.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library, one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.
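
As a rough illustration of the runtime setup described above, the sketch below prepares an environment with `MPITRAMPOLINE_LIB` pointing at a compiled MPIwrapper library before a process is launched. The helper name `mpitrampoline_env` is hypothetical and not part of this PR; only the `MPITRAMPOLINE_LIB` variable comes from the description above.

```python
import os
from pathlib import Path


def mpitrampoline_env(wrapper_lib, base_env=None):
    """Return a copy of the environment with MPITRAMPOLINE_LIB set.

    `wrapper_lib` is the path to the compiled MPIwrapper shared library
    (e.g. libmpiwrapper.so). This helper is a hypothetical convenience,
    not an API provided by XLA, JAX, or MPItrampoline.
    """
    lib = Path(wrapper_lib).expanduser().resolve()
    # Fail early: MPItrampoline dlopens this file at runtime, so a wrong
    # path would otherwise only surface when MPI is first used.
    if not lib.is_file():
        raise FileNotFoundError(f"MPIwrapper library not found: {lib}")
    env = dict(os.environ if base_env is None else base_env)
    env["MPITRAMPOLINE_LIB"] = str(lib)
    return env
```

The returned mapping could then be passed as `env=` to `subprocess.run` when starting the MPI launcher; the exact launcher command depends on the local MPI installation and is not specified here.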

@hawkinsp
Copybara import of the project:

--
b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>:

add mpi collectives

--
23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>:

add explicit Init and Finalize methods and export them to python

--
bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>:

add comment

--
38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>:

fix windows build

--
201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>:

fmt

--
2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>:

bump xla_extension_version

Merging this change closes #7849

PiperOrigin-RevId: 620302264
2024-03-29 13:09:13 -07:00

"""Provides the repository macro to import mpitrampoline."""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Imports mpitrampoline."""

    MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed"
    MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84"

    tf_http_archive(
        name = "mpitrampoline",
        sha256 = MPITRAMPOLINE_SHA256,
        strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT),
        urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)),
        patch_file = ["//third_party/mpitrampoline:gen.patch"],
        build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD",
    )