@@ -162,9 +162,10 @@ void deleter(DLManagedTensor* arg) {
   delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
 }
 
-DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {
+DLManagedTensor* toManagedDLPack(const torch::Tensor& src_) {
   ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
-  atDLMTensor->handle = src;
+  atDLMTensor->handle = src_;
+  auto& src = atDLMTensor->handle;
   atDLMTensor->tensor.manager_ctx = atDLMTensor;
   atDLMTensor->tensor.deleter = &deleter;
   atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
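
This first hunk makes the exported handle self-owning: after copying the argument into atDLMTensor->handle, the alias `auto& src = atDLMTensor->handle;` ensures the data pointer (and any shape/stride metadata filled in below) is read from the tensor the manager itself keeps alive, not from the caller's possibly temporary argument. A minimal consumer-side sketch of the resulting contract, with a hypothetical consume() function that is not part of this patch:

#include <dlpack/dlpack.h>

// Hypothetical consumer: reads the exported tensor, then releases it through
// the deleter installed above, which deletes the ATenDLMTensor and drops the
// reference held in its `handle` member.
void consume(DLManagedTensor* m) {
  const DLTensor& t = m->dl_tensor;
  int64_t numel = 1;
  for (int i = 0; i < t.ndim; ++i) {
    numel *= t.shape[i];  // metadata points into the manager-owned tensor
  }
  m->deleter(m);  // after this call, t.data and t.shape must not be touched
}

Everything dl_tensor points at therefore stays valid exactly until the deleter runs.
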
@@ -238,19 +239,19 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
     }
 
     if (stack[i].isTensor()) {
-      outputs[count++] = toManagedDLPack(stack[i].toTensor().to(output_device));
+      outputs[count++] = toManagedDLPack(stack[i].toTensor().contiguous().to(output_device));
     }
     else if (stack[i].isTensorList()) {
       auto list = stack[i].toTensorList();
       for (size_t j = 0; j < list.size(); j++) {
-        outputs[count++] = toManagedDLPack(list.get(j).to(output_device));
+        outputs[count++] = toManagedDLPack(list.get(j).contiguous().to(output_device));
       }
     }
     else if (stack[i].isTuple()) {
       auto& elements = stack[i].toTuple()->elements();
       for (size_t j = 0; j < elements.size(); j++) {
         if (elements[j].isTensor()) {
-          outputs[count++] = toManagedDLPack(elements[j].toTensor().to(output_device));
+          outputs[count++] = toManagedDLPack(elements[j].toTensor().contiguous().to(output_device));
         }
         else {
           throw std::runtime_error(std::string("Function returned non-tensor values") + fnName);
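
The second hunk inserts .contiguous() before every DLPack export. Module outputs can be non-contiguous views (e.g. results of transpose, permute, or slicing), and a consumer that treats dl_tensor.data as a dense row-major buffer would read scrambled values from such a view; contiguous() materializes a dense copy in that case and returns the tensor unchanged when it is already dense. A standalone illustration, assuming libtorch (not part of the patch):

#include <torch/torch.h>
#include <cassert>

int main() {
  torch::Tensor a = torch::arange(6).reshape({2, 3});
  torch::Tensor t = a.transpose(0, 1);  // view: shape {3,2}, strides {1,3}
  assert(!t.is_contiguous());           // raw buffer is not row-major for {3,2}
  torch::Tensor c = t.contiguous();     // dense row-major copy, safe to export
  assert(c.is_contiguous());
  // No-op when the layout is already dense: same storage, same pointer.
  assert(a.contiguous().data_ptr() == a.data_ptr());
  return 0;
}

Placing contiguous() before to(output_device) means the densifying copy, if one is needed, happens on the tensor's current device; when output_device differs, to() then transfers an already dense tensor.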