|
| 1 | +module TT |
| 2 | + |
| 3 | +using Reactant: Reactant |
| 4 | +using Scratch: @get_scratch! |
| 5 | +using Downloads: Downloads |
| 6 | +using p7zip_jll: p7zip |
| 7 | + |
| 8 | +const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing) |
| 9 | +const tt_pjrt_plugin_name = Ref{String}("pjrt_plugin_tt.so") |
| 10 | + |
| 11 | +function __init__() |
| 12 | + @static if Sys.islinux() |
| 13 | + if !Reactant.precompiling() && has_tt() |
| 14 | + setup_tt_pjrt_plugin!() |
| 15 | + end |
| 16 | + end |
| 17 | +end |
| 18 | + |
| 19 | +force_tt_init() = haskey(ENV, "REACTANT_FORCE_TT_INIT") |
| 20 | + |
| 21 | +function has_tt() |
| 22 | + if force_tt_init() |
| 23 | + return true |
| 24 | + end |
| 25 | + |
| 26 | + # To find whether we have Tenstorrent devices, we can either |
| 27 | + # |
| 28 | + # * look for devices in `/dev/tenstorrent`, or |
| 29 | + # * look for devices in `/sys/bus/pci/devices` with `vendor` equal to `0x1e52`, something like |
| 30 | + # any(readchomp(joinpath(dir, "vendor")) == "0x1e52" for dir in readdir("/sys/bus/pci/devices"; join=true)) |
| 31 | + # |
| 32 | + # The former is simpler for our current purposes, so we can go that way. |
| 33 | + dev_tt = "/dev/tenstorrent" |
| 34 | + return isdir(dev_tt) && length(readdir(dev_tt)) > 0 |
| 35 | +end |
| 36 | + |
| 37 | +function setup_tt_pjrt_plugin!() |
| 38 | + plugin_dir_from_env = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing) |
| 39 | + if plugin_dir_from_env !== nothing && ispath(plugin_dir_from_env) |
| 40 | + tt_pjrt_plugin_dir[] = plugin_dir_from_env |
| 41 | + else |
| 42 | + tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_plugin_tt") |
| 43 | + end |
| 44 | + download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[]) |
| 45 | + return nothing |
| 46 | +end |
| 47 | + |
| 48 | +get_tt_pjrt_plugin_dir() = tt_pjrt_plugin_dir[] |
| 49 | + |
| 50 | +function get_tt_pjrt_plugin_path() |
| 51 | + return joinpath(get_tt_pjrt_plugin_dir(), tt_pjrt_plugin_name[]) |
| 52 | +end |
| 53 | + |
| 54 | +function download_tt_pjrt_plugin_if_needed(dir=nothing) |
| 55 | + dir === nothing && (dir = get_tt_pjrt_plugin_dir()) |
| 56 | + @assert dir !== nothing "tt_pjrt_plugin_dir is not set!" |
| 57 | + |
| 58 | + tt_pjrt_plugin_path = joinpath(dir, tt_pjrt_plugin_name[]) |
| 59 | + if isfile(tt_pjrt_plugin_path) |
| 60 | + @debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do" |
| 61 | + else |
| 62 | + @debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'" |
| 63 | + mktempdir() do tmp_dir |
| 64 | + # Index at https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/ |
| 65 | + zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip") |
| 66 | + wheel_url = if Sys.ARCH === :x86_64 |
| 67 | + "https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251202-cp311-cp311-linux_x86_64.whl" |
| 68 | + else |
| 69 | + error("Unsupported architecture for TT PJRT plugin: $(Sys.ARCH)") |
| 70 | + end |
| 71 | + @debug "Downloading TT PJRT plugin from '$(wheel_url)'" |
| 72 | + Downloads.download(wheel_url, zip_file_path) |
| 73 | + run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull)) |
| 74 | + data_dir = only(filter!(endswith(".data"), readdir(tmp_dir; join=true))) |
| 75 | + # We need to move the entire `pjrt_plugin_tt` directory to the destination. |
| 76 | + mv(joinpath(data_dir, "purelib", "pjrt_plugin_tt"), dir; force=true) |
| 77 | + end |
| 78 | + @assert isfile(tt_pjrt_plugin_path) |
| 79 | + end |
| 80 | +end |
| 81 | + |
| 82 | +end # module TT |
0 commit comments