Skip to content

Conversation

@LittleHeroZZZX
Copy link
Contributor

@LittleHeroZZZX LittleHeroZZZX commented Jan 13, 2026

PR Docs

本 PR 修复差异文档与 PaConvert 中映射规则不一致的问题,具体描述见:PaddlePaddle/docs#7679

PR APIs

Fix validate consistency

- Update torch.nn.MultiheadAttention mapping to use ChangeAPIMatcher with paddle.compat.nn.MultiheadAttention
- Update torch.nn.functional.scaled_dot_product_attention mapping to use ChangeAPIMatcher with paddle.compat.nn.functional.scaled_dot_product_attention
- Add torch.Tensor.shape to global variable manager
- Add test case for torch.Tensor.shape with empty tensor
- Remove unsupport flags from MultiheadAttention tests and add parameter initialization
- Simplify API mapping configurations by removing detailed argument specifications
Changed the tolerance parameter from `rtol` (relative tolerance) to `atol` (absolute tolerance) in all test cases within `test_nn_MultiheadAttention.py`. This adjustment ensures more consistent and reliable test comparisons, particularly for values near zero where relative tolerance can be overly strict.
… attention tests

- Changed tolerance parameter from `rtol` to `atol` in all test cases for more consistent numerical comparisons
- Added new test cases (10-13) for GQA (Grouped Query Attention) and different sequence length scenarios
- Maintains same tolerance values but uses absolute tolerance instead of relative tolerance
@paddle-bot
Copy link

paddle-bot bot commented Jan 13, 2026

Thanks for your contribution!

"inplace"
]
},
"torch.nn.Module": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些配置先不用删,no_need_convert会覆盖api_mapping.json。后面会将no_need_convert列表统一迁移到api_mapping.json中来配置。no_need_convert名单里不用动。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restored

return code


class CompositeMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个的功能是啥,和BaseMatcher一致?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed



def test_case_8():
# PaConvert can't tell if args in **kawrgs are supported by paddle or not.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个功能退化了?之前是可以加未转换标记的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个应该是符合预期的,这个类的init 参数里只有**kwargs,不需要看用户到底传入了些啥参数,PaConvert把它原封不动传进去就行

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个应该是符合预期的,这个类的init 参数里只有**kwargs,不需要看用户到底传入了些啥参数,PaConvert把它原封不动传进去就行

这个case之前是可以跑通的,针对torchvision.models.inception_v3的转换做了修改吗

"torch.nn",
"torch.double",
"torch.cfloat",
"transformers.PreTrainedTokenizer",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个名单里可能还得加几个,baddbmm、round这些

"Matcher": "GenericMatcher",
"paddle_api": "paddle.newaxis"
},
"torch.nn": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些留着吧,非必要不删,有用户报过这种case:

if hasattr(torch.nn, 'xxx')

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restored

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前NO_NEED_CONVERT这一套还是在globar_var里维护的,如果在api_mapping.json里配置。需要保证globar_var配置好,因为官网脚本目前还是会从globar_var里抓数据。

后面再开展:全量NO_NEED_CONVERT迁移到api_mapping.json+更新官网脚本。

"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.log_normal_",
"kwargs_change": {
"generator": ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NoNeedConvertMatcher会处理generator、memory_format,可以计作NoNeedConvertMatcher



def test_case_8():
# PaConvert can't tell if args in **kawrgs are supported by paddle or not.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个应该是符合预期的,这个类的init 参数里只有**kwargs,不需要看用户到底传入了些啥参数,PaConvert把它原封不动传进去就行

这个case之前是可以跑通的,针对torchvision.models.inception_v3的转换做了修改吗

@LittleHeroZZZX
Copy link
Contributor Author

LittleHeroZZZX commented Jan 16, 2026

这个case之前是可以跑通的,针对torchvision.models.inception_v3的转换做了修改吗

这个 case 测试的是当存在不支持的参数时 PaConvert 提示不能转换。torchvision.models.inception_v3 的转换规则做了修改,把 **kargs 直接转给 paddle 对应的类,所以没办法识别出不支持的转换,符合预期,不影响对于参数都支持的情况的转换

"torchvision.models.inception_v3": {
"Matcher": "WeightsMatcher",
"paddle_api": "paddle.vision.models.inception_v3",
"args_list": [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些参数顺序是不是不对?文档中是

infoflow 2026-01-16 16-16-00

由于测试没有覆盖位置参数的用法,导致这里没有发现问题,后面单测还得补充增强。

Copy link
Contributor Author

@LittleHeroZZZX LittleHeroZZZX Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里确实没有对齐,校验工具需要检查参数顺序。但是就这个 API 来说,看了下它的源码,他的第一个位置参数是同时支持是 pretained(bool)或者为 weights 参数的
1768808521

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 API 后面改文档吧?pretrained 参数是弃用的用法,把 weights 放前面更合适

Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后面看下这一系列API的重载形式,确定docs和paconvert的最终版本。

"Matcher": "WeightsMatcher",
"paddle_api": "paddle.vision.models.inception_v3",
"args_list": [
"weights",
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在看起来不是前后问题,是overload的问题,可以按torch.std_mean的文档形式来,两个签名要二选一的话,优先适配其中公开推荐的用法。

@zhwesky2010 zhwesky2010 merged commit c44691d into PaddlePaddle:master Jan 19, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants