@@ -986,7 +986,7 @@ REACTANT_ABI void *UnsafeBufferPointer(PjRtBuffer *buffer) {
986986}
987987
988988REACTANT_ABI void CopyToBuffer (PjRtClient *client, PjRtBuffer *buffer,
989- void *data, size_t offset, size_t size) {
989+ void *data, size_t offset, size_t size, PjRtBuffer **bufferP ) {
990990 if (buffer->IsOnCpu ()) {
991991 auto unsafe =
992992 (char *)MyValueOrThrow (buffer->client ()->UnsafeBufferPointer (buffer));
@@ -995,6 +995,16 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
995995 // data, size);
996996 return ;
997997 }
998+
999+ auto pid = client->platform_id ();
1000+ if (pid == xla::TpuId ()) {
1001+ auto dims = argB->on_device_shape ().dimensions ();
1002+ auto buf2 = ArrayFromHostBuffer (client, data, buffer->element_type (), dims.size (), dims.data (), buffer->device ());
1003+ *bufferP = buf2;
1004+ PjRtBufferFree ((PjRtBuffer *)buffer);
1005+ return ;
1006+ }
1007+
9981008 auto raw_buffer =
9991009 MyValueOrThrow (PjRtRawBuffer::CreateRawAliasOfBuffer (buffer));
10001010 auto future = raw_buffer->CopyRawHostToDevice (data, offset, size);
@@ -1005,7 +1015,6 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10051015 return;
10061016 }
10071017
1008- auto pid = client->platform_id();
10091018 if (pid == xla::CudaId()) {
10101019 auto stream_client = (xla::PjRtStreamExecutorClient*)lrt->client;
10111020
@@ -1032,7 +1041,7 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10321041}
10331042
10341043REACTANT_ABI void CopyFromBuffer (PjRtClient *client, PjRtBuffer *buffer,
1035- void *data, size_t offset, size_t size) {
1044+ void *data, size_t offset, size_t size, PjRtBuffer **bufferP ) {
10361045 auto future = buffer->CopyRawToHost (data, offset, size);
10371046 future.Await ();
10381047#if 0
@@ -3147,14 +3156,14 @@ REACTANT_ABI void reactantXLAMemcpy(LinkableRuntime **__restrict__ lrtP,
31473156 break ;
31483157 case 1 : // cudaMemcpyHostToDevice
31493158 {
3150- auto &&[dstB, dstO, _ ] = bufferAndOffset (lrt, dst);
3151- CopyToBuffer (lrt->client , dstB, src, dstO, size);
3159+ auto &&[dstB, dstO, start ] = bufferAndOffset (lrt, dst);
3160+ CopyToBuffer (lrt->client , dstB, src, dstO, size, start );
31523161 break ;
31533162 }
31543163 case 2 : // cudaMemcpyDeviceToHost
31553164 {
3156- auto &&[srcB, srcO, _ ] = bufferAndOffset (lrt, src);
3157- CopyFromBuffer (lrt->client , srcB, dst, srcO, size);
3165+ auto &&[srcB, srcO, start ] = bufferAndOffset (lrt, src);
3166+ CopyFromBuffer (lrt->client , srcB, dst, srcO, size, start );
31583167 break ;
31593168 }
31603169 case 3 : // cudaMemcpyDeviceToDevice
0 commit comments