Skip to content

Commit 07719f9

Browse files
committed
more copy back
1 parent 950ca8f commit 07719f9

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,20 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10621062
#endif
10631063
}
10641064

1065+
REACTANT_ABI void BufferToHost(PjRtBuffer *buffer, void *data) {
1066+
Shape shape(MyValueOrThrow(buffer->HostShape()));
1067+
/// Grumpily the cpu copy code does not respect layout and does a raw copy
1068+
/// For now, we assume a non-julia row major ordering
1069+
/// If in the future it supports col_major we can swap to that.
1070+
*shape.mutable_layout() = xla::Layout(row_major(shape.dimensions_size()));
1071+
MutableBorrowingLiteral literal((const char *)data, shape);
1072+
auto status = buffer->ToLiteralSync(&literal);
1073+
if (!status.ok()) {
1074+
printf("error copying to host: %s\n", status.ToString().c_str());
1075+
}
1076+
}
1077+
1078+
10651079
REACTANT_ABI void CopyFromBuffer(PjRtClient *client, PjRtBuffer *buffer,
10661080
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
10671081

@@ -1130,19 +1144,6 @@ REACTANT_ABI PjRtBuffer *CopyBufferToDevice(PjRtBuffer *buffer,
11301144
return res.release();
11311145
}
11321146

1133-
REACTANT_ABI void BufferToHost(PjRtBuffer *buffer, void *data) {
1134-
Shape shape(MyValueOrThrow(buffer->HostShape()));
1135-
/// Grumpily the cpu copy code does not respect layout and does a raw copy
1136-
/// For now, we assume a non-julia row major ordering
1137-
/// If in the future it supports col_major we can swap to that.
1138-
*shape.mutable_layout() = xla::Layout(row_major(shape.dimensions_size()));
1139-
MutableBorrowingLiteral literal((const char *)data, shape);
1140-
auto status = buffer->ToLiteralSync(&literal);
1141-
if (!status.ok()) {
1142-
printf("error copying to host: %s\n", status.ToString().c_str());
1143-
}
1144-
}
1145-
11461147
REACTANT_ABI void FreeClient(PjRtClient *client) { delete client; }
11471148

11481149
REACTANT_ABI int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) {

0 commit comments

Comments
 (0)