Skip to content

Commit

Permalink
[mosaic_gpu] Include Mosaic GPU dialect fiels into jaxlib
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry committed Dec 23, 2024
1 parent 3e7f481 commit 8987867
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions jaxlib/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package(
py_library(
name = "mosaic",
deps = [
"//jaxlib/mosaic/python:gpu_dialect",
"//jaxlib/mosaic/python:tpu_dialect",
],
)
Expand Down
1 change: 1 addition & 0 deletions jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def has_ext_modules(self):
'cuda/*',
'cuda/nvvm/libdevice/libdevice*',
'mosaic/*.py',
'mosaic/dialect/gpu/*.py',
'mosaic/gpu/*.so',
'mosaic/python/*.py',
'mosaic/python/*.so',
Expand Down
12 changes: 12 additions & 0 deletions jaxlib/tools/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,24 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
dst_dir=mosaic_python_dir,
src_files=[
"__main__/jaxlib/mosaic/python/layout_defs.py",
"__main__/jaxlib/mosaic/python/mosaic_gpu.py",
"__main__/jaxlib/mosaic/python/tpu.py",
],
)
# TODO (sharadmv,skyewm): can we avoid patching this file?
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
)
mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu"
os.makedirs(mosaic_gpu_dir)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py",
dst_dir=mosaic_gpu_dir,
)
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py",
dst_dir=mosaic_gpu_dir,
)

copy_runfiles(
dst_dir=jaxlib_dir / "mlir",
Expand Down Expand Up @@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
Expand Down

0 comments on commit 8987867

Please sign in to comment.