Conform to PyTorch convention in the loss #163

Open
wiseodd opened this issue Apr 18, 2024 · 1 comment · May be fixed by #177

wiseodd (Collaborator) commented Apr 18, 2024

Currently we always assume that the class index in the model's outputs is -1. While this works for standard models and Huggingface LLMs, there are important models for which this is false, e.g. image-output models where the logit tensor has shape (n_batch, n_classes, height, width), i.e. the class index is 1.
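
For context, a minimal sketch contrasting the two layouts (shapes here are illustrative). Note that PyTorch's own `cross_entropy` already expects the class dimension at index 1 for dense outputs:

```python
import torch

# Standard classifier / Huggingface LLM: class index is the last dim (-1).
logits_flat = torch.randn(8, 10)          # (n_batch, n_classes)

# Image-output model (e.g. segmentation): class index is dim 1, which is
# what torch.nn.functional.cross_entropy expects for (N, C, d1, d2, ...).
logits_img = torch.randn(8, 10, 32, 32)   # (n_batch, n_classes, height, width)
targets_img = torch.randint(0, 10, (8, 32, 32))

loss = torch.nn.functional.cross_entropy(logits_img, targets_img)  # class dim at 1
```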

@wiseodd wiseodd added this to the 0.2 milestone Apr 18, 2024
@wiseodd wiseodd self-assigned this Apr 18, 2024
wiseodd (Collaborator, Author) commented Apr 26, 2024

The current idea is to add an arg in BaseLaplace: logit_class_idx: int = -1. Then, whenever Laplace flattens the logits, it will use that argument to locate the class dimension.

Test cases for a conv last layer, which yields logits of shape (batch_size, n_classes, height, width), should be added to cover this use case; see the sketch below.
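
A minimal sketch of how such flattening could work, assuming a hypothetical helper flatten_logits (not the library's actual API):

```python
import torch

def flatten_logits(logits: torch.Tensor, logit_class_idx: int = -1) -> torch.Tensor:
    """Move the class dim to the end, then flatten all other dims.

    Hypothetical helper illustrating the proposal above.
    """
    n_classes = logits.shape[logit_class_idx]
    # movedim puts the class axis last so reshape yields (n_points, n_classes).
    return logits.movedim(logit_class_idx, -1).reshape(-1, n_classes)

# Conv last layer: (batch_size, n_classes, height, width)
logits = torch.randn(8, 10, 32, 32)
out = flatten_logits(logits, logit_class_idx=1)
assert out.shape == (8 * 32 * 32, 10)

# Default keeps today's behavior for (batch_size, n_classes) outputs.
assert flatten_logits(torch.randn(8, 10)).shape == (8, 10)
```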

@wiseodd wiseodd modified the milestones: 0.2, 0.3 Jul 4, 2024