Skip to content

Conversation

@jacenfox
Copy link
Collaborator

It's a bug fixed version.

Copy link
Owner

@MathGaron MathGaron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already implemented in the framework. Check the cat_dog exemple :
in your network init :
Basically you pass a lambda function that takes the gradient as input and outputs whatever
and pass a tag name so you can retrieve it (fc1 in this case)
self.hook_fc1 = self.hook_generator(lambda x: float(torch.sqrt(torch.sum(x ** 2))), "fc1")
in your network forward :
if self.training: x.register_hook(self.hook_fc1)
and to retrieve the gradient in your callback:
grad_log = state.model.grad_data["fc1"]

I am sorry about that, I should really take the time to document it properly....

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants