Florence-2 | Default LoRA config produces suboptimal results | Found a better config #162
Comments
Hi @patel-zeel, great analysis!
Absolutely. We just need to develop a consistent system where users can do this through both the SDK and the CLI.
Thank you, @SkalskiP. Yes, that'd be great. Having a library that allows fine-tuning VLMs with a single CLI command would be fantastic. I'd happily contribute now, or later when the library is relatively stable.
One of our main goals right now is to design the solution in a way that remains flexible enough to handle similar fine-tuning scenarios in the future. Your input and code would be tremendously appreciated, whether it's refining the design, writing new features, or improving documentation. Let's work together to make it as robust and reusable as possible!

I see 2 main ways we can implement this:

Approach 1: Extend Florence2Configuration with LoRA settings

SDK Example

```python
from maestro.trainer.models.florence_2.core import train, Florence2Configuration

config = Florence2Configuration(
    dataset="dataset/location",
    epochs=10,
    batch_size=4,
    optimization_strategy="lora",
    metrics=["edit_distance"],
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_bias="none",
    lora_init_lora_weights="gaussian",
    lora_use_rslora=True,
    lora_inference_mode=False
)

train(config)
```

CLI Example

```bash
maestro florence_2 train \
  --dataset "dataset/location" \
  --epochs 10 \
  --batch-size 4 \
  --optimization_strategy "lora" \
  --metrics "edit_distance" \
  --lora-r 8 \
  --lora-alpha 16 \
  --lora-dropout 0.05 \
  --lora-bias "none" \
  --lora-init-lora-weights "gaussian" \
  --lora-use-rslora True \
  --lora-inference-mode False
```

Approach 2: Single "advanced" parameter argument with inlined JSON/YAML

Add a single parameter (e.g., `peft_advanced_params`) that accepts the LoRA settings as a dictionary.

SDK Example

```python
from maestro.trainer.models.florence_2.core import train, Florence2Configuration

advanced_params = {
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    "init_lora_weights": "gaussian",
    "use_rslora": True
}

config = Florence2Configuration(
    dataset="dataset/location",
    epochs=10,
    batch_size=4,
    optimization_strategy="lora",
    metrics=["edit_distance"],
    peft_advanced_params=advanced_params
)

train(config)
```

CLI Example

```bash
maestro florence_2 train \
  --dataset "dataset/location" \
  --epochs 10 \
  --batch-size 4 \
  --optimization_strategy "lora" \
  --metrics "edit_distance" \
  --peft-advanced-params '{"r":8,"lora_alpha":16,"lora_dropout":0.05,"bias":"none","use_rslora":true}'
```

@PawelPeczek-Roboflow @Matvezy @probicheaux what do you think about this?
I am leaning toward the second approach so that the API doesn't need to be changed to accommodate every minor change.
👍 on schema-less configs. Plus, it would be great to have the ability to load configs from files and override them with params given explicitly in the CLI command; this way you can have a base config that you quickly modify at will while running a training.
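To illustrate the base-config idea, a minimal sketch; the file format, the `load_config` helper, and the `--config` flag are assumptions, not part of maestro:

```python
from pathlib import Path

import yaml  # PyYAML; the YAML format here is just an illustrative choice

def load_config(path: str, cli_overrides: dict) -> dict:
    """Load a base config file and override it with params given explicitly on the CLI."""
    base = yaml.safe_load(Path(path).read_text()) or {}
    # Explicit CLI values take precedence over the base config file.
    return {**base, **cli_overrides}

# Hypothetical CLI: maestro florence_2 train --config base.yaml --lora-alpha 16
config = load_config("base.yaml", {"lora_alpha": 16})
print(config)
```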
@PawelPeczek-Roboflow @SkalskiP Considering the base config idea, it seems that a hierarchical config could be unintuitive for users who just want to change a single parameter, e.g. `lora_alpha`.

Default values hardcoded in the library

SDK Example for one parameter change

```python
from maestro.trainer.models.florence_2.core import train, Florence2Configuration

config = Florence2Configuration(
    dataset="dataset/location",
    lora_alpha=16,  # overrides the default value
)

train(config)
```

CLI Example for one parameter change
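A matching CLI call for the same one-parameter change, assuming the flag naming convention from the earlier examples (the `--lora-alpha` flag here is hypothetical):

```bash
maestro florence_2 train \
  --dataset "dataset/location" \
  --lora-alpha 16
```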
Hi @patel-zeel 👋🏻 sorry for the lack of contact over the past few days. I had to focus on projects other than maestro for a while, but I'm coming back. I talked to @Matvezy and @probicheaux in private messages, and I think I'm leaning towards solution number 2, potentially adding support for config files that @PawelPeczek-Roboflow suggested in the future (not as part of this task). @patel-zeel would you like to work on the implementation?
No worries, @SkalskiP!
Sure. Do you mean we don't want to add config file support just yet, but first enable the dictionary-like support suggested in solution 2?
Exactly! 👍🏻
Search before asking
Bug
The default LoRA config used in `maestro` is defined in `maestro/trainer/models/florence_2/checkpoints.py`, lines 50 to 57 at commit `cecc78f`.

The LoRA config used in the Florence-2 fine-tuning on custom dataset Roboflow notebook is the following:
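For reference, a PEFT `LoraConfig` of this kind typically looks like the sketch below; the values and `target_modules` are illustrative assumptions rather than necessarily the notebook's exact settings:

```python
from peft import LoraConfig

# Illustrative values only; check the Roboflow notebook for the exact settings.
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
    init_lora_weights="gaussian",
    use_rslora=True,
    inference_mode=False,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
    task_type="CAUSAL_LM",
)
```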
For the `poker-cards-fmjio` dataset, the default LoRA config of `maestro` results in a mAP50 value of 0.20, but the Roboflow notebook config results in a mAP50 value of 0.52. I experimentally found a config that results in a mAP50 value of 0.71. Please see the Minimal Reproducible Example below for more.

Environment
Minimal Reproducible Example
I used 4 variants of the LoRA config; the results are described below.
Configs

- Maestro default
- Maestro default + Gaussian init
- Roboflow notebook default
- Roboflow notebook default except `lora_alpha=16`
Metrics
I used the Roboflow notebook to run the pipeline for 10 epochs and compute the metrics. I have used the new evaluation API as follows:
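A minimal sketch of such an evaluation, assuming the `supervision` metrics API (v0.24+) and lists of `Detections` already collected from the validation set:

```python
from supervision import Detections
from supervision.metrics import MeanAveragePrecision, MetricTarget

def evaluate_map(predictions: list[Detections], targets: list[Detections]) -> None:
    """Compute mAP for model predictions against ground-truth annotations."""
    map_metric = MeanAveragePrecision(metric_target=MetricTarget.BOXES)
    map_metric.update(predictions, targets)
    result = map_metric.compute()
    print(f"mAP50: {result.map50:.2f}, mAP50-95: {result.map50_95:.2f}")
```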
Results
Conclusion

Using `lora_alpha=16` in the "Roboflow notebook default" LoRA config results in much better performance with the same number of epochs.

Questions
Should the LoRA config be defined in a `toml` or `json` or any other format file, with users then providing the path of the config file to the maestro CLI?

Additional
No response
Are you willing to submit a PR?