
[Question] two-tower-model + infoNCE how to optimize #718

Open

unshaven opened this issue Jun 5, 2024 · 1 comment

Comments

unshaven commented Jun 5, 2024

I have been testing a two-tower model (user and query towers) in a real industrial scenario using contrastive learning. The samples are all actual click samples, and the loss function is InfoNCE. I have a few questions:

  1. The model performs best with only one MLP layer; the more layers I add, the worse HR@100 becomes.
  2. Adding L2 normalization to the final embeddings degrades performance.

As a result, I currently use only one MLP layer and no normalization. Could you offer some advice or share experiences on what I should try?
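
For context, here is a minimal sketch of the kind of InfoNCE loss I mean, using in-batch negatives (the temperature value and function names are illustrative, not my production code):

import tensorflow as tf

def info_nce_loss(user_emb, item_emb, temperature=0.07):
  # user_emb, item_emb: [batch_size, dim]; row i of each tower is a positive pair.
  user_emb = tf.math.l2_normalize(user_emb, axis=1)
  item_emb = tf.math.l2_normalize(item_emb, axis=1)
  # [batch_size, batch_size] cosine similarities; off-diagonal entries serve
  # as in-batch negatives. With normalized embeddings the similarities lie in
  # [-1, 1], so a temperature well below 1 is needed to spread the logits.
  logits = tf.matmul(user_emb, item_emb, transpose_b=True) / temperature
  # The positive for row i sits on the diagonal (column i).
  labels = tf.range(tf.shape(logits)[0])
  return tf.reduce_mean(
      tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True))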


rlcauvin commented Jun 6, 2024

Did you write your own implementation of the InfoNCE loss function, or are you using an existing implementation? I'm interested in trying it. Are you using it in your retrieval model or ranking model?

While I haven't used an InfoNCE loss function, I've found that an inverse time decay learning-rate schedule works really well to avoid overfitting in my retrieval and ranking models, for example:

import tensorflow as tf

initial_learning_rate = 0.0007
# One decay unit per epoch: cardinality() is the number of batches in the
# cached training dataset.
decay_steps = cached_train_ds.cardinality().numpy()
decay_rate = 2.4

# Learning rate at step t: initial_learning_rate / (1 + decay_rate * t / decay_steps)
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
  initial_learning_rate = initial_learning_rate,
  decay_steps = decay_steps,
  decay_rate = decay_rate,
  staircase = False)
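
The schedule can then be passed straight to a Keras optimizer (Adagrad here is just an example; any Keras optimizer accepts a schedule as its learning_rate, and this assumes a tf.keras/TFRS model):

optimizer = tf.keras.optimizers.Adagrad(learning_rate = lr_schedule)
model.compile(optimizer = optimizer)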
