Currently we always assume that the class index in the model's outputs is `-1`. This holds for standard models and Hugging Face LLMs, but it is false for some important models. For example, image-output models produce a logit tensor of shape `(n_batch, n_classes, height, width)`, so the class index is `1`.
The current idea is to add an argument to `BaseLaplace`: `logit_class_idx: int = -1`. Whenever Laplace flattens logits, it would then use this index to locate the class dimension.
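A minimal sketch of what index-aware flattening could look like, assuming plain PyTorch tensors. `flatten_logits` is a hypothetical helper, not part of the laplace-torch API. It moves the dimension at `logit_class_idx` to the end and collapses all remaining dimensions, so each row holds the class logits of one prediction:

```python
import torch


def flatten_logits(logits: torch.Tensor, logit_class_idx: int = -1) -> torch.Tensor:
    """Hypothetical helper: flatten logits to shape (n_predictions, n_classes)."""
    n_classes = logits.shape[logit_class_idx]
    # Move the class axis to the last position (a no-op when it is already -1).
    logits = torch.movedim(logits, logit_class_idx, -1)
    # Collapse batch/spatial/sequence axes; reshape copies if the result is non-contiguous.
    return logits.reshape(-1, n_classes)


# Image-output model: (n_batch, n_classes, height, width), class index 1.
seg_logits = torch.randn(2, 3, 4, 5)
print(flatten_logits(seg_logits, logit_class_idx=1).shape)  # torch.Size([40, 3])

# LLM-style logits: (n_batch, seq_len, n_classes), default class index -1.
lm_logits = torch.randn(2, 7, 10)
print(flatten_logits(lm_logits).shape)  # torch.Size([14, 10])
```

With the default `logit_class_idx=-1`, the segmentation logits above would instead be flattened to shape `(24, 5)`. The width axis would be silently treated as the class axis, which is the failure mode this issue describes.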
Test cases for a convolutional last layer, which produces logits of shape `(batch_size, n_classes, height, width)`, should be added to cover this use case.
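As a starting point, here is a sketch of such a test in plain PyTorch. The toy model and test names are hypothetical, and the test only checks the logit-shape contract and per-pixel flattening along dimension `1`. A full test would additionally fit Laplace on this model with the proposed `logit_class_idx=1`, which does not exist yet and is therefore left out:

```python
import torch
import torch.nn as nn


def make_conv_last_layer_model(in_channels: int = 3, n_classes: int = 4) -> nn.Module:
    """Hypothetical toy segmentation-style model whose last layer is a 1x1 conv."""
    return nn.Sequential(
        nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, n_classes, kernel_size=1),
    )


def test_conv_last_layer_logit_shape() -> None:
    torch.manual_seed(0)
    model = make_conv_last_layer_model(in_channels=3, n_classes=4)
    x = torch.randn(2, 3, 6, 6)
    logits = model(x)
    # Logits are (batch_size, n_classes, height, width): the class index is 1, not -1.
    assert logits.shape == (2, 4, 6, 6)
    # Flattening along class dim 1 yields one row of class logits per pixel.
    flat = torch.movedim(logits, 1, -1).reshape(-1, logits.shape[1])
    assert flat.shape == (2 * 6 * 6, 4)


test_conv_last_layer_logit_shape()
```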