Skip to content

Commit 0695d6b

Browse files
committed
Add TPU C api copytodevice
1 parent 409fdf7 commit 0695d6b

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
986986
}
987987

988988
REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
989-
void *data, size_t offset, size_t size) {
989+
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
990990
if (buffer->IsOnCpu()) {
991991
auto unsafe =
992992
(char *)MyValueOrThrow(buffer->client()->UnsafeBufferPointer(buffer));
@@ -995,6 +995,16 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
995995
// data, size);
996996
return;
997997
}
998+
999+
auto pid = client->platform_id();
1000+
if (pid == xla::TpuId()) {
1001+
auto dims = argB->on_device_shape().dimensions();
1002+
auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(), dims.size(), dims.data(), buffer->device());
1003+
*bufferP = buf2;
1004+
PjRtBufferFree((PjRtBuffer *)buffer);
1005+
return;
1006+
}
1007+
9981008
auto raw_buffer =
9991009
MyValueOrThrow(PjRtRawBuffer::CreateRawAliasOfBuffer(buffer));
10001010
auto future = raw_buffer->CopyRawHostToDevice(data, offset, size);
@@ -1005,7 +1015,6 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10051015
return;
10061016
}
10071017

1008-
auto pid = client->platform_id();
10091018
if (pid == xla::CudaId()) {
10101019
auto stream_client = (xla::PjRtStreamExecutorClient*)lrt->client;
10111020

@@ -1032,7 +1041,7 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10321041
}
10331042

10341043
REACTANT_ABI void CopyFromBuffer(PjRtClient *client, PjRtBuffer *buffer,
1035-
void *data, size_t offset, size_t size) {
1044+
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
10361045
auto future = buffer->CopyRawToHost(data, offset, size);
10371046
future.Await();
10381047
#if 0
@@ -3147,14 +3156,14 @@ REACTANT_ABI void reactantXLAMemcpy(LinkableRuntime **__restrict__ lrtP,
31473156
break;
31483157
case 1: // cudaMemcpyHostToDevice
31493158
{
3150-
auto &&[dstB, dstO, _] = bufferAndOffset(lrt, dst);
3151-
CopyToBuffer(lrt->client, dstB, src, dstO, size);
3159+
auto &&[dstB, dstO, start] = bufferAndOffset(lrt, dst);
3160+
CopyToBuffer(lrt->client, dstB, src, dstO, size, start);
31523161
break;
31533162
}
31543163
case 2: // cudaMemcpyDeviceToHost
31553164
{
3156-
auto &&[srcB, srcO, _] = bufferAndOffset(lrt, src);
3157-
CopyFromBuffer(lrt->client, srcB, dst, srcO, size);
3165+
auto &&[srcB, srcO, start] = bufferAndOffset(lrt, src);
3166+
CopyFromBuffer(lrt->client, srcB, dst, srcO, size, start);
31583167
break;
31593168
}
31603169
case 3: // cudaMemcpyDeviceToDevice

0 commit comments

Comments
 (0)