@@ -43,8 +43,8 @@ limitations under the License.
4343#include " jaxlib/py_client.h"
4444#include " jaxlib/py_device.h"
4545#include " jaxlib/py_user_context.h"
46+ #include " xla/future.h"
4647#include " xla/hlo/ir/hlo_module.h"
47- #include " xla/pjrt/pjrt_future.h"
4848#include " xla/pjrt/pjrt_layout.h"
4949#include " xla/python/ifrt/array.h"
5050#include " xla/python/ifrt/device.h"
@@ -220,7 +220,7 @@ static ifrt::ArrayRef GetIfRtArray(const ExecuteShardedArg& arg) {
220220
221221void PopulateExecuteShardedResults (const nb_class_ptr<PyClient>& client,
222222 std::vector<ifrt::ArrayRef> ifrt_arrays,
223- const xla::PjRtFuture <>& result_status,
223+ const xla::Future <>& result_status,
224224 int num_computations,
225225 std::vector<std::vector<PyArray>>& outputs) {
226226 DCHECK_GT (num_computations, 0 );
@@ -246,11 +246,11 @@ absl::StatusOr<PyExecuteResults> ExecuteShardedOnLocalDevicesInternal(
246246 const ifrt::ExecuteOptions& options, const nb_class_ptr<PyClient>& client,
247247 ifrt::LoadedExecutable* ifrt_loaded_executable,
248248 absl::Span<const ExecuteShardedArg> args,
249- std::optional<std::vector<xla::PjRtFuture <>>>& returned_futures) {
249+ std::optional<std::vector<xla::Future <>>>& returned_futures) {
250250 std::vector<ifrt::ArrayRef> output_arrays;
251251 std::unique_ptr<tsl::Future<>> returned_future;
252252 int num_computations = ifrt_loaded_executable->addressable_devices ().size ();
253- xla::PjRtFuture <> result_status;
253+ xla::Future <> result_status;
254254 {
255255 nb::gil_scoped_release gil_release;
256256 for (const auto & arg : args) {
@@ -301,7 +301,7 @@ absl::StatusOr<PyExecuteResults> ExecuteShardedOnLocalDevicesInternal(
301301PyExecuteResults::PyExecuteResults (const nb_class_ptr<PyClient>& client,
302302 std::vector<ifrt::ArrayRef> ifrt_arrays,
303303 int num_computations, PyShardedToken token,
304- xla::PjRtFuture <> result_status)
304+ xla::Future <> result_status)
305305 : client_(client),
306306 ifrt_arrays_ (std::move(ifrt_arrays)),
307307 num_computations_(num_computations),
@@ -333,7 +333,7 @@ PyExecuteResults::DisassembleIntoSingleDeviceArrays() {
333333 std::vector<std::vector<PyArray>> outputs;
334334 PopulateExecuteShardedResults (
335335 client_, Consume (),
336- result_status_.IsValid () ? result_status_ : xla::PjRtFuture <>(),
336+ result_status_.IsValid () ? result_status_ : xla::Future <>(),
337337 num_computations_, outputs);
338338 return outputs;
339339}
@@ -357,7 +357,7 @@ PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) {
357357 std::vector<std::vector<PyArray>> outputs;
358358 PopulateExecuteShardedResults (
359359 client_, std::move (ifrt_arrays),
360- result_status_.IsValid () ? result_status_ : xla::PjRtFuture <>(),
360+ result_status_.IsValid () ? result_status_ : xla::Future <>(),
361361 num_computations_, outputs);
362362 return outputs;
363363}
@@ -382,7 +382,7 @@ std::vector<nb::object> PyExecuteResults::ConsumeWithHandlers(
382382 if (std::holds_alternative<const PyArrayResultHandler*>(handler)) {
383383 outputs.push_back (std::get<const PyArrayResultHandler*>(handler)->Call (
384384 client_, std::move (ifrt_arrays[buffer_id]),
385- result_status_.IsValid () ? result_status_ : xla::PjRtFuture <>()));
385+ result_status_.IsValid () ? result_status_ : xla::Future <>()));
386386 } else {
387387 tsl::profiler::TraceMe traceme (" ConsumeWithHandlers fallback." );
388388 auto disassembled_arrays =
@@ -396,7 +396,7 @@ std::vector<nb::object> PyExecuteResults::ConsumeWithHandlers(
396396 for (auto & disassembled_array : *disassembled_arrays) {
397397 nb::object array = PyArray::MakeFromSingleDeviceArray (
398398 client_, std::move (disassembled_array), false , true ,
399- result_status_.IsValid () ? result_status_ : xla::PjRtFuture <>());
399+ result_status_.IsValid () ? result_status_ : xla::Future <>());
400400 PyList_SET_ITEM (bufs.ptr (), i, array.release ().ptr ());
401401 ++i;
402402 }
@@ -417,7 +417,7 @@ absl::StatusOr<PyExecuteResults> PyLoadedExecutable::ExecuteSharded(
417417 }
418418 PyUserContextScope user_context_scope;
419419 PopulateCallLocation (options, xla::ifrt::UserContextScope::current ().get ());
420- std::optional<std::vector<xla::PjRtFuture <>>> returned_futures;
420+ std::optional<std::vector<xla::Future <>>> returned_futures;
421421 if (with_tokens) {
422422 returned_futures.emplace ();
423423 }
0 commit comments