Skip to content

Commit 5762ebf

Browse files
authored
Create TT client via plugin for Tenstorrent devices (#1860)
* Create TT client via plugin for Tenstorrent devices * Automatically download wheel for TT PJRT plugin * Remove self-qualified accesses * [tt plugin] Implement logic to detect Tenstorrent devices * [tt plugin] More comments * Update URL of TT PJRT plugin * [docs] Mention Tenstorrent as experimental backend
1 parent cca938c commit 5762ebf

File tree

6 files changed

+151
-0
lines changed

6 files changed

+151
-0
lines changed

docs/src/index.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,8 @@ using Reactant
8181
Reactant.set_default_backend("tpu")
8282
```
8383

84+
```julia [Tenstorrent (Experimental)]
85+
using Reactant
86+
Reactant.set_default_backend("tt")
87+
```
8488
:::

src/accelerators/Accelerators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ module Accelerators
22

33
include("TPU.jl")
44
include("Metal.jl")
5+
include("TT.jl")
56

67
end

src/accelerators/TT.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
module TT
2+
3+
using Reactant: Reactant
4+
using Scratch: @get_scratch!
5+
using Downloads: Downloads
6+
using p7zip_jll: p7zip
7+
8+
# Directory that holds the TT PJRT plugin; stays `nothing` until a directory is assigned.
const tt_pjrt_plugin_dir = Ref{Union{String,Nothing}}(nothing)
9+
# File name of the TT PJRT plugin shared library inside the plugin directory.
const tt_pjrt_plugin_name = Ref("pjrt_plugin_tt.so")
10+
11+
"""
    __init__()

Module initialization hook. On Linux, and only when Reactant is not precompiling and
`has_tt()` reports a Tenstorrent device (or a forced init), set up the TT PJRT plugin.
On every other platform this is a no-op.
"""
function __init__()
    @static if Sys.islinux()
        # Never touch the plugin while precompiling.
        Reactant.precompiling() && return nothing
        has_tt() && setup_tt_pjrt_plugin!()
    end
    return nothing
end
18+
19+
"""
    force_tt_init() -> Bool

Return `true` when the environment variable `REACTANT_FORCE_TT_INIT` is defined,
regardless of its value.
"""
function force_tt_init()
    return haskey(ENV, "REACTANT_FORCE_TT_INIT")
end
20+
21+
"""
    has_tt() -> Bool

Return `true` if TT initialization is forced via `force_tt_init()`, or if the directory
`/dev/tenstorrent` exists and contains at least one entry.
"""
function has_tt()
    # An explicit override short-circuits the device probe.
    force_tt_init() && return true

    # Two ways of detecting Tenstorrent devices are available:
    #
    #   * look for devices in `/dev/tenstorrent`, or
    #   * look for devices in `/sys/bus/pci/devices` whose `vendor` is `0x1e52`, e.g.
    #         any(readchomp(joinpath(dir, "vendor")) == "0x1e52" for dir in readdir("/sys/bus/pci/devices"; join=true))
    #
    # The first option is simpler for the current purposes, so it is the one used here.
    device_dir = "/dev/tenstorrent"
    isdir(device_dir) || return false
    return !isempty(readdir(device_dir))
end
36+
37+
"""
    setup_tt_pjrt_plugin!()

Select the directory of the TT PJRT plugin, store it in `tt_pjrt_plugin_dir`, and make
sure the plugin is available there via `download_tt_pjrt_plugin_if_needed`.

The directory named by the environment variable `TT_PJRT_PLUGIN_DIR` is used when that
variable is set and points to an existing path; otherwise the `pjrt_plugin_tt` scratch
space is used. Returns `nothing`.
"""
function setup_tt_pjrt_plugin!()
    env_dir = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing)
    use_env_dir = !isnothing(env_dir) && ispath(env_dir)
    tt_pjrt_plugin_dir[] = use_env_dir ? env_dir : @get_scratch!("pjrt_plugin_tt")
    download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
    return nothing
end
47+
48+
"""
    get_tt_pjrt_plugin_dir()

Return the value currently stored in `tt_pjrt_plugin_dir`.
"""
function get_tt_pjrt_plugin_dir()
    return tt_pjrt_plugin_dir[]
end
49+
50+
"""
    get_tt_pjrt_plugin_path() -> String

Return the full path of the TT PJRT plugin library, i.e. the plugin file name joined
onto the directory returned by `get_tt_pjrt_plugin_dir()`.

Throws an `ErrorException` with an explanatory message if the plugin directory has not
been set yet, instead of failing with a `MethodError` from `joinpath(nothing, ...)`.
"""
function get_tt_pjrt_plugin_path()
    plugin_dir = get_tt_pjrt_plugin_dir()
    if plugin_dir === nothing
        error(
            "TT PJRT plugin directory is not set; ",
            "the TT PJRT plugin has not been set up in this session",
        )
    end
    return joinpath(plugin_dir, tt_pjrt_plugin_name[])
end
53+
54+
"""
    download_tt_pjrt_plugin_if_needed(dir=nothing)

Make sure the TT PJRT plugin library is present in `dir`, installing it when missing.

If `dir` is `nothing`, the directory returned by `get_tt_pjrt_plugin_dir()` is used.
When the plugin file already exists nothing is done. Otherwise the plugin wheel is
downloaded into a temporary directory, unpacked with `p7zip`, and its `pjrt_plugin_tt`
directory is moved to `dir` (replacing `dir`, since `force=true`).

# Throws
- `ArgumentError` if no target directory is available.
- `ErrorException` on an unsupported architecture, or if the plugin file is still
  missing after the installation.

Returns `nothing`.
"""
function download_tt_pjrt_plugin_if_needed(dir=nothing)
    # Bind the resolved directory to a fresh local so that the closure below does not
    # capture a reassigned variable.
    target_dir = dir === nothing ? get_tt_pjrt_plugin_dir() : dir
    # Explicit check instead of `@assert`: assertions may be disabled.
    target_dir === nothing && throw(ArgumentError("tt_pjrt_plugin_dir is not set!"))

    tt_pjrt_plugin_path = joinpath(target_dir, tt_pjrt_plugin_name[])
    if isfile(tt_pjrt_plugin_path)
        @debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do"
        return nothing
    end

    @debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'"
    # Fail on unsupported architectures before creating any temporary directory.
    # Index at https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/
    wheel_url = if Sys.ARCH === :x86_64
        "https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251202-cp311-cp311-linux_x86_64.whl"
    else
        error("Unsupported architecture for TT PJRT plugin: $(Sys.ARCH)")
    end

    mktempdir() do tmp_dir
        zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip")
        @debug "Downloading TT PJRT plugin from '$(wheel_url)'"
        Downloads.download(wheel_url, zip_file_path)
        # The wheel is unpacked as a zip archive; extraction output is discarded.
        run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull))
        data_dir = only(filter!(endswith(".data"), readdir(tmp_dir; join=true)))
        # We need to move the entire `pjrt_plugin_tt` directory to the destination.
        mv(joinpath(data_dir, "purelib", "pjrt_plugin_tt"), target_dir; force=true)
    end

    # Verify the installation explicitly rather than through `@assert`.
    if !isfile(tt_pjrt_plugin_path)
        error("TT PJRT plugin not found at '$(tt_pjrt_plugin_path)' after installation")
    end
    return nothing
end
81+
82+
end # module TT

src/xla/IFRT/Client.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
115115
const cuda_client_count = Ref(0)
116116
const tpu_client_count = Ref(0)
117117
const metal_client_count = Ref(0)
118+
const tt_client_count = Ref(0)
118119

119120
for (backend, counter) in (
120121
(:CPUClient, :cpu_client_count),
121122
(:CUDAClient, :cuda_client_count),
122123
(:TPUClient, :tpu_client_count),
123124
(:MetalClient, :metal_client_count),
125+
(:TTClient, :tt_client_count),
124126
)
125127
main_fn = Symbol(:MakeIFRTPJRT, backend)
126128
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -219,6 +221,17 @@ function MakeIFRTPJRTMetalClient(;
219221
)
220222
end
221223

224+
"""
    MakeIFRTPJRTTTClient(;
        tt_pjrt_plugin_path, node_id=0, num_nodes=1, distributed_runtime_client=nothing
    )

Build the TT client by delegating to `MakeIFRTPJRTClientViaPluginAPI` with the plugin
library at `tt_pjrt_plugin_path`, device type `"tt"` and client name `"TT"`. The
keywords `node_id`, `num_nodes` and `distributed_runtime_client` are forwarded unchanged.
"""
function MakeIFRTPJRTTTClient(;
    tt_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    device_type = "tt"
    client_name = "TT"
    return MakeIFRTPJRTClientViaPluginAPI(
        tt_pjrt_plugin_path,
        device_type,
        client_name;
        node_id=node_id,
        num_nodes=num_nodes,
        distributed_runtime_client=distributed_runtime_client,
    )
end
234+
222235
function MakeIFRTPJRTClientViaPluginAPI(
223236
library_path::String,
224237
device_type::String,

src/xla/PJRT/Client.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
110110
const cuda_client_count = Ref(0)
111111
const tpu_client_count = Ref(0)
112112
const metal_client_count = Ref(0)
113+
const tt_client_count = Ref(0)
113114

114115
for (backend, counter) in (
115116
(:CPUClient, :cpu_client_count),
116117
(:CUDAClient, :cuda_client_count),
117118
(:TPUClient, :tpu_client_count),
118119
(:MetalClient, :metal_client_count),
120+
(:TTClient, :tt_client_count),
119121
)
120122
main_fn = Symbol(:Make, backend)
121123
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -207,6 +209,20 @@ function MakeMetalClient(;
207209
return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
208210
end
209211

212+
"""
    MakeTTClient(;
        tt_pjrt_plugin_path, node_id=0, num_nodes=1, distributed_runtime_client=nothing
    )

Build the TT client by delegating to `MakeClientUsingPluginAPI` with the plugin library
at `tt_pjrt_plugin_path`, device type `"tt"` and client name `"TT"`.

The distributed keywords are accepted but only in their default configuration:
`node_id == 0`, `num_nodes == 1` and `distributed_runtime_client === nothing`; any
other value trips an assertion.
"""
function MakeTTClient(;
    tt_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    # Only the single-node, non-distributed configuration is handled here.
    @assert node_id == 0 "`PJRT.MakeTTClient` does not support node_id"
    @assert num_nodes == 1 "`PJRT.MakeTTClient` does not support num_nodes > 1"
    @assert isnothing(distributed_runtime_client) "`PJRT.MakeTTClient` does not support \
                                                   distributed_runtime_client"

    return MakeClientUsingPluginAPI(tt_pjrt_plugin_path, "tt", "TT")
end
225+
210226
function MakeClientUsingPluginAPI(
211227
library_path::String, device_type::String, client_name::String=uppercase(device_type)
212228
)

src/xla/XLA.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,41 @@ for runtime in (:PJRT, :IFRT)
226226
catch e
227227
println(stdout, e)
228228
end
229+
elseif Accelerators.TT.has_tt()
230+
@debug "TT accelerator detected, setting it up"
231+
try
232+
if was_initialized && haskey(state.clients, "tt")
233+
free_client(state.clients["tt"])
234+
$(runtime).tt_client_count[] -= 1
235+
end
236+
# The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client.
# A value already present in the environment is respected; otherwise fall back to the
# `tt-metal` directory located next to the PJRT plugin library, if it exists.
tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)
if isnothing(tt_metal_runtime_root)
    tt_metal_path_in_wheel = joinpath(
        dirname(Accelerators.TT.get_tt_pjrt_plugin_path()),
        "tt-metal",
    )
    if ispath(tt_metal_path_in_wheel)
        @debug "Setting environment variable 'TT_METAL_RUNTIME_ROOT' to '$(tt_metal_path_in_wheel)'"
        ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel
    else
        error(
            "`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it",
        )
    end
else
    # Fixed duplicated word ("to to") in the message below.
    @debug "Environment variable 'TT_METAL_RUNTIME_ROOT' already set to '$(tt_metal_runtime_root)'"
end
254+
255+
tt = $(runtime).TTClient(;
256+
tt_pjrt_plugin_path=Accelerators.TT.get_tt_pjrt_plugin_path(),
257+
common_kwargs...,
258+
)
259+
state.clients["tt"] = tt
260+
state.default_client = tt
261+
catch e
262+
println(stdout, e)
263+
end
229264
elseif Reactant_jll.host_platform.tags["gpu"] != "none"
230265
try
231266
if was_initialized && haskey(state.clients, "cuda")

0 commit comments

Comments
 (0)