Skip to content

Commit 376bc38

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[jax] Migrate jaxlib to xla::Future
PiperOrigin-RevId: 824307165
1 parent 4ef87ac commit 376bc38

File tree

8 files changed

+44
-45
lines changed

8 files changed

+44
-45
lines changed

jaxlib/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ cc_library(
10211021
"@tsl//tsl/platform:ml_dtypes",
10221022
"@tsl//tsl/profiler/lib:traceme",
10231023
"@xla//third_party/python_runtime:headers", # buildcleaner: keep
1024+
"@xla//xla:future",
10241025
"@xla//xla:literal",
10251026
"@xla//xla:shape_util",
10261027
"@xla//xla:status_macros",
@@ -1034,7 +1035,6 @@ cc_library(
10341035
"@xla//xla/pjrt:pjrt_client",
10351036
"@xla//xla/pjrt:pjrt_compiler",
10361037
"@xla//xla/pjrt:pjrt_executable",
1037-
"@xla//xla/pjrt:pjrt_future",
10381038
"@xla//xla/pjrt:pjrt_layout",
10391039
"@xla//xla/pjrt:status_casters",
10401040
"@xla//xla/python:nb_absl_span",
@@ -1172,9 +1172,9 @@ cc_library(
11721172
"@llvm-project//llvm:Support",
11731173
"@nanobind",
11741174
"@tsl//tsl/platform:casts",
1175+
"@xla//xla:future",
11751176
"@xla//xla:util",
11761177
"@xla//xla/pjrt:pjrt_client",
1177-
"@xla//xla/pjrt:pjrt_future",
11781178
"@xla//xla/pjrt:status_casters",
11791179
"@xla//xla/pjrt/distributed:client",
11801180
"@xla//xla/pjrt/distributed:key_value_store_interface",
@@ -1353,8 +1353,8 @@ cc_library(
13531353
"@com_google_absl//absl/time",
13541354
"@com_google_absl//absl/types:span",
13551355
"@nanobind",
1356+
"@xla//xla:future",
13561357
"@xla//xla:util",
1357-
"@xla//xla/pjrt:pjrt_future",
13581358
"@xla//xla/python:version",
13591359
"@xla//xla/python/ifrt",
13601360
"@xla//xla/tsl/concurrency:async_value",

jaxlib/py_array.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ limitations under the License.
6969
#include "jaxlib/to_ifrt_sharding.h"
7070
#include "jaxlib/traceback.h"
7171
#include "jaxlib/util.h"
72+
#include "xla/future.h"
7273
#include "xla/layout.h"
7374
#include "xla/layout_util.h"
7475
#include "xla/pjrt/exceptions.h"
7576
#include "xla/pjrt/lru_cache.h"
7677
#include "xla/pjrt/pjrt_client.h"
7778
#include "xla/pjrt/pjrt_compiler.h"
78-
#include "xla/pjrt/pjrt_future.h"
7979
#include "xla/pjrt/pjrt_layout.h"
8080
#include "xla/pjrt/status_casters.h"
8181
#include "xla/primitive_util.h"
@@ -475,7 +475,7 @@ PyArray_Storage::PyArray_Storage(nb::object aval, bool weak_type,
475475
nb::object sharding, bool committed,
476476
nb_class_ptr<PyClient> py_client,
477477
ifrt::ArrayRef ifrt_array,
478-
xla::PjRtFuture<> result_status)
478+
xla::Future<> result_status)
479479
: aval(std::move(aval)),
480480
weak_type(weak_type),
481481
dtype(std::move(dtype)),
@@ -512,7 +512,7 @@ void PyInit_helper(PyArray self, nb::object aval, nb::object sharding,
512512
Construct(reinterpret_cast<PyArrayObject*>(self.ptr()), aval,
513513
nb::cast<bool>(aval.attr("weak_type")), std::move(dtype),
514514
std::move(shape), std::move(sharding), committed, py_client,
515-
std::move(ifrt_array), xla::PjRtFuture<>());
515+
std::move(ifrt_array), xla::Future<>());
516516
}
517517

518518
void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding,
@@ -533,7 +533,7 @@ void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding,
533533
PyArray PyArray::MakeFromSingleDeviceArray(nb_class_ptr<PyClient> py_client,
534534
ifrt::ArrayRef ifrt_array,
535535
bool weak_type, bool committed,
536-
xla::PjRtFuture<> result_status) {
536+
xla::Future<> result_status) {
537537
if (!llvm::isa<ifrt::SingleDeviceSharding>(ifrt_array->sharding())) {
538538
throw xla::XlaRuntimeError(xla::InvalidArgument(
539539
"Constructing single device jax.Array from non-single "
@@ -606,27 +606,27 @@ PyArray PyArrayResultHandler::Call(absl::Span<const PyArray> py_arrays) const {
606606
return Call(py_device_list.value()->py_client(),
607607
CreateIfRtArrayFromSingleDeviceShardedPyArrays(
608608
dtype_, shape_, py_arrays, sharding_),
609-
xla::PjRtFuture<>());
609+
xla::Future<>());
610610
}
611611

612612
PyArray PyArrayResultHandler::Call(nb_class_ptr<PyClient> py_client,
613613
ifrt::ArrayRef ifrt_array,
614-
xla::PjRtFuture<> result_status) const {
614+
xla::Future<> result_status) const {
615615
return PyArray(aval_, weak_type_, dtype_, shape_, sharding_,
616616
std::move(py_client), std::move(ifrt_array), committed_,
617617
skip_checks_, std::move(result_status));
618618
}
619619

620620
PyArray PyArrayResultHandler::Call(PyArray py_array) const {
621621
return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()),
622-
xla::PjRtFuture<>());
622+
xla::Future<>());
623623
}
624624

625625
PyArray::PyArray(nb::object aval, bool weak_type, xla::nb_dtype dtype,
626626
std::vector<int64_t> shape, nb::object sharding,
627627
nb_class_ptr<PyClient> py_client, ifrt::ArrayRef ifrt_array,
628628
bool committed, bool skip_checks,
629-
xla::PjRtFuture<> result_status) {
629+
xla::Future<> result_status) {
630630
auto* self =
631631
PyArray_tp_new(reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr);
632632
m_ptr = self;

jaxlib/py_array.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ limitations under the License.
3838
#include "jaxlib/py_client.h"
3939
#include "jaxlib/py_user_context.h"
4040
#include "jaxlib/traceback.h"
41+
#include "xla/future.h"
4142
#include "xla/pjrt/exceptions.h"
4243
#include "xla/pjrt/pjrt_client.h"
43-
#include "xla/pjrt/pjrt_future.h"
4444
#include "xla/pjrt/pjrt_layout.h"
4545
#include "xla/python/ifrt/array.h"
4646
#include "xla/python/ifrt/device_list.h"
@@ -99,8 +99,7 @@ struct PyArray_Storage {
9999
PyArray_Storage(nanobind::object aval, bool weak_type, xla::nb_dtype dtype,
100100
std::vector<int64_t> shape, nanobind::object sharding,
101101
bool committed, nb_class_ptr<PyClient> py_client,
102-
xla::ifrt::ArrayRef ifrt_array,
103-
xla::PjRtFuture<> result_status);
102+
xla::ifrt::ArrayRef ifrt_array, xla::Future<> result_status);
104103

105104
~PyArray_Storage();
106105
nanobind::handle AsHandle();
@@ -125,7 +124,7 @@ struct PyArray_Storage {
125124
// Only set if this Array was generated by a computation that has effects.
126125
// This is the result status of the XLA computation that generated this
127126
// array.
128-
xla::PjRtFuture<> result_status;
127+
xla::Future<> result_status;
129128

130129
// Doubly-linked list of all PyArrays known to the client. Protected by the
131130
// GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be
@@ -158,12 +157,12 @@ class PyArray : public nanobind::object {
158157
std::vector<int64_t> shape, nanobind::object sharding,
159158
nb_class_ptr<PyClient> py_client, xla::ifrt::ArrayRef ifrt_array,
160159
bool committed, bool skip_checks,
161-
xla::PjRtFuture<> result_status = xla::PjRtFuture<>());
160+
xla::Future<> result_status = xla::Future<>());
162161

163162
static PyArray MakeFromSingleDeviceArray(
164163
nb_class_ptr<PyClient> py_client, xla::ifrt::ArrayRef ifrt_array,
165164
bool weak_type, bool committed,
166-
xla::PjRtFuture<> result_status = xla::PjRtFuture<>());
165+
xla::Future<> result_status = xla::Future<>());
167166

168167
static PyArray MakeFromIfrtArrayAndSharding(nb_class_ptr<PyClient> py_client,
169168
xla::ifrt::ArrayRef ifrt_array,
@@ -226,7 +225,7 @@ class PyArray : public nanobind::object {
226225
}
227226

228227
// Returns xla::InvalidArgument if the buffer has been deleted.
229-
// See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`.
228+
// See `Future` for the semantics of `IsReady` and `IsKnownReady`.
230229
absl::StatusOr<bool> IsReady() {
231230
xla::ifrt::Array* ifrt_array_ptr = ifrt_array();
232231
if (ifrt_array_ptr->IsDeleted()) {
@@ -237,7 +236,7 @@ class PyArray : public nanobind::object {
237236
return ifrt_array_ptr->GetReadyFuture().IsReady();
238237
}
239238

240-
const xla::PjRtFuture<>& result_status() const {
239+
const xla::Future<>& result_status() const {
241240
return GetStorage().result_status;
242241
}
243242

@@ -365,7 +364,7 @@ class PyArrayResultHandler {
365364
PyArray Call(PyArray py_array) const;
366365

367366
PyArray Call(nb_class_ptr<PyClient> py_client, xla::ifrt::ArrayRef ifrt_array,
368-
xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const;
367+
xla::Future<> result_status = xla::Future<>()) const;
369368

370369
private:
371370
nanobind::object aval_;

jaxlib/py_executable.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ limitations under the License.
4343
#include "jaxlib/py_client.h"
4444
#include "jaxlib/py_device.h"
4545
#include "jaxlib/py_user_context.h"
46+
#include "xla/future.h"
4647
#include "xla/hlo/ir/hlo_module.h"
47-
#include "xla/pjrt/pjrt_future.h"
4848
#include "xla/pjrt/pjrt_layout.h"
4949
#include "xla/python/ifrt/array.h"
5050
#include "xla/python/ifrt/device.h"
@@ -220,7 +220,7 @@ static ifrt::ArrayRef GetIfRtArray(const ExecuteShardedArg& arg) {
220220

221221
void PopulateExecuteShardedResults(const nb_class_ptr<PyClient>& client,
222222
std::vector<ifrt::ArrayRef> ifrt_arrays,
223-
const xla::PjRtFuture<>& result_status,
223+
const xla::Future<>& result_status,
224224
int num_computations,
225225
std::vector<std::vector<PyArray>>& outputs) {
226226
DCHECK_GT(num_computations, 0);
@@ -246,11 +246,11 @@ absl::StatusOr<PyExecuteResults> ExecuteShardedOnLocalDevicesInternal(
246246
const ifrt::ExecuteOptions& options, const nb_class_ptr<PyClient>& client,
247247
ifrt::LoadedExecutable* ifrt_loaded_executable,
248248
absl::Span<const ExecuteShardedArg> args,
249-
std::optional<std::vector<xla::PjRtFuture<>>>& returned_futures) {
249+
std::optional<std::vector<xla::Future<>>>& returned_futures) {
250250
std::vector<ifrt::ArrayRef> output_arrays;
251251
std::unique_ptr<tsl::Future<>> returned_future;
252252
int num_computations = ifrt_loaded_executable->addressable_devices().size();
253-
xla::PjRtFuture<> result_status;
253+
xla::Future<> result_status;
254254
{
255255
nb::gil_scoped_release gil_release;
256256
for (const auto& arg : args) {
@@ -301,7 +301,7 @@ absl::StatusOr<PyExecuteResults> ExecuteShardedOnLocalDevicesInternal(
301301
PyExecuteResults::PyExecuteResults(const nb_class_ptr<PyClient>& client,
302302
std::vector<ifrt::ArrayRef> ifrt_arrays,
303303
int num_computations, PyShardedToken token,
304-
xla::PjRtFuture<> result_status)
304+
xla::Future<> result_status)
305305
: client_(client),
306306
ifrt_arrays_(std::move(ifrt_arrays)),
307307
num_computations_(num_computations),
@@ -333,7 +333,7 @@ PyExecuteResults::DisassembleIntoSingleDeviceArrays() {
333333
std::vector<std::vector<PyArray>> outputs;
334334
PopulateExecuteShardedResults(
335335
client_, Consume(),
336-
result_status_.IsValid() ? result_status_ : xla::PjRtFuture<>(),
336+
result_status_.IsValid() ? result_status_ : xla::Future<>(),
337337
num_computations_, outputs);
338338
return outputs;
339339
}
@@ -357,7 +357,7 @@ PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) {
357357
std::vector<std::vector<PyArray>> outputs;
358358
PopulateExecuteShardedResults(
359359
client_, std::move(ifrt_arrays),
360-
result_status_.IsValid() ? result_status_ : xla::PjRtFuture<>(),
360+
result_status_.IsValid() ? result_status_ : xla::Future<>(),
361361
num_computations_, outputs);
362362
return outputs;
363363
}
@@ -382,7 +382,7 @@ std::vector<nb::object> PyExecuteResults::ConsumeWithHandlers(
382382
if (std::holds_alternative<const PyArrayResultHandler*>(handler)) {
383383
outputs.push_back(std::get<const PyArrayResultHandler*>(handler)->Call(
384384
client_, std::move(ifrt_arrays[buffer_id]),
385-
result_status_.IsValid() ? result_status_ : xla::PjRtFuture<>()));
385+
result_status_.IsValid() ? result_status_ : xla::Future<>()));
386386
} else {
387387
tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback.");
388388
auto disassembled_arrays =
@@ -396,7 +396,7 @@ std::vector<nb::object> PyExecuteResults::ConsumeWithHandlers(
396396
for (auto& disassembled_array : *disassembled_arrays) {
397397
nb::object array = PyArray::MakeFromSingleDeviceArray(
398398
client_, std::move(disassembled_array), false, true,
399-
result_status_.IsValid() ? result_status_ : xla::PjRtFuture<>());
399+
result_status_.IsValid() ? result_status_ : xla::Future<>());
400400
PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr());
401401
++i;
402402
}
@@ -417,7 +417,7 @@ absl::StatusOr<PyExecuteResults> PyLoadedExecutable::ExecuteSharded(
417417
}
418418
PyUserContextScope user_context_scope;
419419
PopulateCallLocation(options, xla::ifrt::UserContextScope::current().get());
420-
std::optional<std::vector<xla::PjRtFuture<>>> returned_futures;
420+
std::optional<std::vector<xla::Future<>>> returned_futures;
421421
if (with_tokens) {
422422
returned_futures.emplace();
423423
}

jaxlib/py_executable.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ limitations under the License.
3636
#include "jaxlib/py_array.h"
3737
#include "jaxlib/py_client.h"
3838
#include "jaxlib/traceback.h"
39+
#include "xla/future.h"
3940
#include "xla/hlo/ir/hlo_module.h"
4041
#include "xla/pjrt/exceptions.h"
4142
#include "xla/pjrt/pjrt_client.h"
4243
#include "xla/pjrt/pjrt_executable.h"
43-
#include "xla/pjrt/pjrt_future.h"
4444
#include "xla/pjrt/pjrt_layout.h"
4545
#include "xla/python/ifrt/array.h"
4646
#include "xla/python/ifrt/attribute_map.h"
@@ -53,24 +53,24 @@ namespace jax {
5353
class PyToken {
5454
public:
5555
PyToken() = default;
56-
explicit PyToken(xla::PjRtFuture<> future) : future_(std::move(future)) {}
56+
explicit PyToken(xla::Future<> future) : future_(std::move(future)) {}
5757

5858
static PyToken ReadyPyToken() {
59-
return PyToken(xla::PjRtFuture<>(absl::OkStatus()));
59+
return PyToken(xla::Future<>(absl::OkStatus()));
6060
}
6161

6262
absl::Status Await();
6363

6464
private:
65-
xla::PjRtFuture<> future_;
65+
xla::Future<> future_;
6666
};
6767

6868
// PyShardedToken contains a PyToken for each device's execution.
6969
class PyShardedToken {
7070
public:
7171
// Default construction creates a always-ready token.
7272
PyShardedToken() = default;
73-
explicit PyShardedToken(std::vector<xla::PjRtFuture<>> futures)
73+
explicit PyShardedToken(std::vector<xla::Future<>> futures)
7474
: futures_(std::move(futures)) {}
7575

7676
PyToken GetPyToken(int device_id) const {
@@ -81,15 +81,15 @@ class PyShardedToken {
8181
absl::Status Await();
8282

8383
private:
84-
std::vector<xla::PjRtFuture<>> futures_;
84+
std::vector<xla::Future<>> futures_;
8585
};
8686

8787
class PyExecuteResults {
8888
public:
8989
PyExecuteResults(const nb_class_ptr<PyClient>& client,
9090
std::vector<xla::ifrt::ArrayRef> ifrt_arrays,
9191
int num_computations, PyShardedToken token,
92-
xla::PjRtFuture<> result_status = xla::PjRtFuture<>());
92+
xla::Future<> result_status = xla::Future<>());
9393

9494
std::vector<std::vector<PyArray>> DisassembleIntoSingleDeviceArrays();
9595

@@ -119,7 +119,7 @@ class PyExecuteResults {
119119
int num_computations_;
120120
PyShardedToken token_;
121121
// Only set if the computation has tokens.
122-
xla::PjRtFuture<> result_status_;
122+
xla::Future<> result_status_;
123123
};
124124

125125
using ExecuteShardedArg = std::variant<PyArray, std::vector<PyArray>>;

jaxlib/py_socket_transfer.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ limitations under the License.
4444
#include "jaxlib/py_executable.h"
4545
#include "jaxlib/py_user_context.h"
4646
#include "jaxlib/to_ifrt_sharding.h"
47+
#include "xla/future.h"
4748
#include "xla/pjrt/distributed/client.h"
4849
#include "xla/pjrt/distributed/key_value_store_interface.h"
4950
#include "xla/pjrt/pjrt_client.h"
50-
#include "xla/pjrt/pjrt_future.h"
5151
#include "xla/pjrt/status_casters.h"
5252
#include "xla/python/ifrt/array.h"
5353
#include "xla/python/ifrt/array_spec.h"
@@ -424,12 +424,12 @@ void RegisterTransferServerTypes(nanobind::module_& m) {
424424
nb::repr(slice).c_str(), device_size)
425425
.c_str());
426426
}
427-
std::vector<xla::PjRtFuture<>> futures_per_array;
427+
std::vector<xla::Future<>> futures_per_array;
428428
for (auto& buffer : arrs[i]->pjrt_buffers()) {
429429
auto raw_buffer = xla::ValueOrThrow(
430430
xla::PjRtRawBuffer::CreateRawAliasOfBuffer(buffer.get()));
431431
tsl::RCReference<ChunkDestination> dest;
432-
xla::PjRtFuture<> future;
432+
xla::Future<> future;
433433
std::tie(dest, future) = xla::ValueOrThrow(
434434
CreateSlicedRawBufferDest(raw_buffer, start, total_size));
435435
futures_per_array.push_back(std::move(future));

jaxlib/util.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ limitations under the License.
2323
#include "absl/time/time.h"
2424
#include "absl/types/span.h"
2525
#include "nanobind/nanobind.h"
26-
#include "xla/pjrt/pjrt_future.h"
26+
#include "xla/future.h"
2727
#include "xla/python/ifrt/array.h"
2828
#include "xla/python/ifrt/client.h"
2929
#include "xla/python/ifrt/value.h"
@@ -37,7 +37,7 @@ namespace ifrt = xla::ifrt;
3737

3838
namespace jax {
3939

40-
void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future) {
40+
void BlockUntilReadyWithCancel(xla::Future<>& future) {
4141
future.BlockUntilReady([](tsl::AsyncValue* value) {
4242
auto state = std::make_shared<absl::Notification>();
4343
value->AndThen([state]() { state->Notify(); });

jaxlib/util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ limitations under the License.
1818

1919
#include "absl/status/status.h"
2020
#include "absl/types/span.h"
21-
#include "xla/pjrt/pjrt_future.h"
21+
#include "xla/future.h"
2222
#include "xla/python/ifrt/array.h"
2323

2424
namespace jax {
2525

2626
// Waits until future is ready but will cancel if ctrl-c is pressed.
27-
void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future);
27+
void BlockUntilReadyWithCancel(xla::Future<>& future);
2828

2929
// Requests if given buffers are ready, awaits for results and returns OK if
3030
// all of the buffers are ready or the last non-ok status.

0 commit comments

Comments
 (0)