tensorflow/third_party/pybind11_bazel/pybind11_bazel.patch
Antonio Sanchez f366199c00 Add ml_dtypes dependency to TSL/TensorFlow.
All custom types shared between TF/TSL/JAX are upstreamed to the
ml_dtypes package.  This change adds that dependency.  A follow-up
will replace TF's existing definitions with the upstreamed ones.

Moved the `pybind11_bazel` dependency since it is used both by
`ml_dtypes` and `pybind11_abseil`.

PiperOrigin-RevId: 539157226
2023-06-09 13:05:54 -07:00


diff --git a/build_defs.bzl b/build_defs.bzl
index cde1e93..03f14a5 100644
--- a/build_defs.bzl
+++ b/build_defs.bzl
@@ -27,7 +27,9 @@ PYBIND_DEPS = [
 
 # Builds a Python extension module using pybind11.
 # This can be directly used in python with the import statement.
-# This adds rules for a .so binary file.
+# This adds rules for .so and .pyd binary files, as well as
+# a base target that selects between them depending on the platform
+# (.pyd for windows, .so otherwise).
 def pybind_extension(
         name,
         copts = [],
@@ -59,6 +61,21 @@ def pybind_extension(
         **kwargs
     )
 
+    native.genrule(
+        name = name + "_pyd",
+        srcs = [name + ".so"],
+        outs = [name + ".pyd"],
+        cmd = "cp $< $@",
+    )
+
+    native.py_library(
+        name = name,
+        data = select({
+            "@platforms//os:windows": [":" + name + ".pyd"],
+            "//conditions:default": [":" + name + ".so"],
+        }),
+    )
+
 # Builds a pybind11 compatible library. This can be linked to a pybind_extension.
 def pybind_library(
         name,
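The BUILD sketch below shows how a consumer might use the patched macro. It is not taken from TensorFlow.

The target names `my_ext` and `my_ext_test` and the files `my_ext.cc` and `my_ext_test.py` are hypothetical. The load label assumes the repository is registered as `@pybind11_bazel`, as in TensorFlow's workspace. The sketch also assumes a Bazel version where the native `py_test` rule is available.

After the patch, `name` is itself a `py_library` whose `data` is either `my_ext.so` or `my_ext.pyd`, chosen by the target platform. This means Python targets can depend on `:my_ext` instead of naming a platform-specific file.

```starlark
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

# With the patch applied, this declares:
#   my_ext.so  - the shared library the macro already built
#   my_ext_pyd - a genrule that copies my_ext.so to my_ext.pyd
#   my_ext     - a py_library whose data is my_ext.pyd on Windows, my_ext.so elsewhere
pybind_extension(
    name = "my_ext",  # hypothetical module name
    srcs = ["my_ext.cc"],  # hypothetical pybind11 source file
)

py_test(
    name = "my_ext_test",  # hypothetical test target
    srcs = ["my_ext_test.py"],
    # Depend on the platform-neutral py_library rather than ":my_ext.so".
    deps = [":my_ext"],
)
```

The copy is needed because CPython on Windows only imports native extension modules with a `.pyd` suffix. On other platforms, a `.so` file is accepted.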