Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neel/spatial map answer extraction #57

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
8 changes: 4 additions & 4 deletions eureka_ml_insights/data_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from .prompt_processing import JinjaPromptTemplate
from .spatial_utils import (
ExtractAnswerGrid,
ExtractAnswerMaze,
ExtractAnswerSpatialMap,
ExtractAnswerSpatialMapAndMaze,
ExtractQuestionOptions,
)
from .transform import (
AddColumn,
Expand Down Expand Up @@ -69,8 +69,8 @@
ASTEvalTransform,
PrependStringTransform,
ExtractAnswerGrid,
ExtractAnswerSpatialMap,
ExtractAnswerMaze,
ExtractAnswerSpatialMapAndMaze,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add "ExtractQuestionOptions" so that when we run the formatters this import does not get removed (considered unused).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh, yes will do

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done and checked in

ExtractQuestionOptions,
ShuffleColumnsTransform,
ColumnMatchMapTransform,
TokenCounterTransform,
Expand Down
275 changes: 122 additions & 153 deletions eureka_ml_insights/data_utils/spatial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,25 +157,34 @@ def extract_answer_from_text_grid(text, question_type):
return None # Return None if no numbers are found


def extract_answer_from_text_map(text, question_type, model_name):
def extract_answer_from_text_map_and_maze(model_output_raw, options):
"""
Extracts the answer from the text based on specific patterns,
and as a fallback, extracts the first number if no patterns match.
The code is from: https://github.com/alvinmingwisc/spatial_reason_vlm/tree/main/eval,
and included with minimal modifications.
Extracts the answer from the text based on known model output patterns.
Searches for both a letter and whole word answer and returns both as they are not
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh is this why you have added the OR metrics? If so, I don't think this justifies adding new metric classes, instead the answer extractor should return some combination of these two, maybe simply "x or y".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to do this another way. Especially if that means less new code. However, if let's say I return "x or y", I still need to a metric that knows how to check this. the current ones just look for case-sensitive or insensitive exact match.

One though was to make one metric that can take a single or a list (instead of having two as now) and an optional parameter for how to combine, i.e. any ("or") or all ("and"). what do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now returning the "x or y" string

always consistent.

Args:
- text (str): The text containing the model's answer.
- question_type (str): The text containing the question type.
- model_name (str): The model name.
- model_output_raw (str): The text containing the model's answer.
- options (str): The list of options.

Returns:
- str or None: The extracted answer, or None if no answer could be extracted.
- str or None: The extracted answers, or empty strings if no answer could be extracted.
"""
# Mapping of textual numbers to their numeric equivalents

# replace common subsitutions in model outputs

model_output_parsed_letter = ""
model_output_parsed = ""

if not model_output_raw:
return [model_output_parsed, model_output_parsed_letter]

model_output_raw = re.sub(r"\bno objects\b", "0 objects", model_output_raw, re.IGNORECASE)
model_output_raw = re.sub(r"\bnot\b", "no", model_output_raw, re.IGNORECASE)
model_output_raw = re.sub(r"\bshould be\b", "is", model_output_raw, re.IGNORECASE)

number_mapping = {
"zero": 0,
"no": 0,
"zero": 0,
"one": 1,
"two": 2,
"three": 3,
Expand All @@ -187,127 +196,71 @@ def extract_answer_from_text_map(text, question_type, model_name):
"nine": 9,
}

dirs = ["southeast", "northeast", "northwest", "southwest"]
dir_pattern = rf"\b({'|'.join(dirs)})\b"

if text is None:
return None

question_id = int(re.search("[0-9]", re.search("Q[0-9]", question_type).group()).group())

if question_id == 0:
direction_match = re.search(r"\b[A-D]\.\s*(" + "|".join(dirs) + r")\b", text, re.IGNORECASE)
if direction_match:
return direction_match.group(1).lower()

match = re.search(dir_pattern, text, re.IGNORECASE)
if match:
return match.group(1)
return None

elif question_id == 1:
match = re.search(
rf"^([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|located\s+to\s+the\s+)({dir_pattern})",
text,
re.IGNORECASE,
)

if match:
string = match.group(1)
return string

match = re.search(r"\b[A-D]\.\s*(.*)", text) # problem with extracting .

if match:
string = match.group(1)
string = remove_redundancy(string)
string = extract_before_is(string)
return string

match = re.search(r"\b([ABCD][.,]|[(][abcdABCD][)])\s*(.*?)(?=\sis\b|\.|,|<|$)", text)
if match:
answer = match.group(1).strip()
# Remove trailing punctuation if any
answer = re.sub(r"[\.,\?!<]+$", "", answer)
return answer

match = re.search(
rf"Therefore, the object in the {dir_pattern} of [\w\s\'\']+ is ([\w\s\'\']+)", text, re.IGNORECASE
)
if match:
string = match.group(2)
return string

if "claude" in model_name.lower():
match = re.search(rf"^([\w\s\'\']+?)\s+is\s+(to\s+the\s+)({dir_pattern})", text, re.IGNORECASE)
if match:
string = match.group(1)
return string

if "gemini" in model_name.lower():
patterns = [
rf"\*\*Concise Answer:\*\*\n([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})",
rf"\*\*Answer:\*\*\s+([\w\s\'\']+?)\s+is\s+in\s+the\s+({dir_pattern})\s+of\s+([\w\s\'\']+)",
r"\*\*Answer:\*\*\n([\w\s\'\']+)",
r"\*\*Answer\*\*:\s+([\w\s\'\']+)",
r"\*\*Answer:\*\*\s+([\w\s\'\']+)",
]

for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
return match.group(1)

if "gpt-4o" in model_name.lower() or "gpt4o" in model_name.lower():
match = re.search(
rf"Concise Answer:\s+([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})",
text,
re.IGNORECASE,
)
if match:
string = match.group(1)
return string

# If no match, check for an answer following "is", with specific end markers defined
match = re.search(r"\bis\b\s+(.*?)(?=\.|,|<|$)", text)
if match:
answer = match.group(1).strip()
# Remove trailing punctuation if any
answer = re.sub(r"[\.,\?!<]+$", "", answer)
return answer

return None # Return None if no match is found

elif question_id == 2:
match = re.search(r"\b[A-D]\.\s*(\d+)", text) # match number only
if match:
return match.group(1)
# Create a list to store all found numbers along with their positions
found_numbers = []

# Check for textual numbers and their positions
for text_num, num in number_mapping.items():
for match in re.finditer(rf"\b{text_num}\b", text, re.IGNORECASE):
found_numbers.append((match.start(), num))

# Check for digit sequences and their positions, specifically ignoring list markers at the start
# Exclude numbers following "\n\n" and directly followed by ". "
text = re.sub(r"^\n\n\d+\.\s", "", text) # Remove the leading list marker if it exists

for match in re.finditer(r"\d+", text):
found_numbers.append((match.start(), int(match.group(0))))

# Sort found numbers by their positions (smallest position first)
if found_numbers:
found_numbers.sort(key=lambda x: x[0])
# Return the number associated with the earliest position
return str(found_numbers[0][1])
return None

else:
raise ValueError(f"Question ID {question_id} is not supported.")

return None # Return None if no numbers are found
for k, v in number_mapping.items():
model_output_raw = re.sub(rf"\b{k}\b", str(v), model_output_raw, re.IGNORECASE)

# get dict of options from options string
options_dict = {x.split(".")[0].strip().lower():x.split(".")[1].strip().lower() for x in options}


model_output_parsed_letter = ""
model_output_parsed = ""

answers = [v for k, v in options_dict.items()]
answers_pattern = rf"\b({'|'.join(answers)})\b"

if "Answer:".lower() in model_output_raw.lower():
pattern_letter = r"^\**Answer:\**\s+(\w)\. (\w+)"
matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
if matches:
match_option = matches.group(1).lower()
neelsj marked this conversation as resolved.
Show resolved Hide resolved
if match_option in options_dict:
model_output_parsed_letter = options_dict[match_option]
else:
model_output_parsed_letter = match_option

pattern_phrase = r"Answer:\**\s+([^\n]+)"
matches = re.search(pattern_phrase, model_output_raw, re.IGNORECASE)
if matches:
model_output_answer_line = matches.group(1)

answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)

if answers_match:
model_output_parsed = answers_match.group(1)
else:
letters = [k for k, v in options_dict.items()]
letters_pattern = rf"\b({'|'.join(letters)})\b"
letters_pattern_match = re.search(letters_pattern, model_output_answer_line, re.IGNORECASE)

if letters_pattern_match:
match_option = letters_pattern_match.group(1).lower()
model_output_parsed_letter = options_dict[match_option]

elif "answer is".lower() in model_output_raw.lower():
pattern_letter = r'answer is:*\s*\**([\w\d]+)[\s:.]*\**'

# first look for a single letter answer
matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
if matches:
match_option = matches.group(1).lower()
if match_option in options_dict:
model_output_parsed_letter = options_dict[match_option]
else:
model_output_parsed_letter = match_option

# next look if any of the options names are present in the first line

model_output_answer_line = model_output_raw.splitlines()[0]

answers = [v for k, v in options_dict.items()]
answers_pattern = rf"\b({'|'.join(answers)})\b"
answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)

if answers_match:
model_output_parsed = answers_match.group(1)

return model_output_parsed + " or " + model_output_parsed_letter


def extract_answer_from_text_maze(text, question_type):
Expand Down Expand Up @@ -443,43 +396,59 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
)
return df


@dataclass
class ExtractAnswerGrid(ExtractAnswer):
"""This class is an answer extractor for the GRID benchmark."""
class ExtractQuestionOptions(DFTransformBase):
"""This class is for extracting the option list from a prompt."""

answer_column_name: str
extracted_answer_column_name: str
question_type_column_name: str
mode: str
prompt_column_name: str
extracted_options_column_name: str

@abstractmethod
def _parse_answer_function(self, answer_text, question_type):
return extract_answer_from_text_grid(answer_text, question_type)
def _extract_options_from_text_map(self, prompt):
"""
Extracts the multiple-choice options list from the text.

Args:
- text (str): The text containing the prompt.

Returns:
- str or None: The extracted list of options.
"""

# get list of options from prompt
prompt_lines = prompt.splitlines()
matches = [i for i, x in enumerate(prompt_lines) if "Available options:" in x]
options = prompt_lines[matches[0]+1:matches[0]+5]

return options

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
df[self.extracted_options_column_name] = df[self.prompt_column_name].apply(self._extract_options_from_text_map)
return df

@dataclass
class ExtractAnswerSpatialMap(ExtractAnswer):
"""This class is an answer extractor for the SPATIAL_MAP benchmark."""
class ExtractAnswerGrid(ExtractAnswer):
"""This class is an answer extractor for the GRID benchmark."""

answer_column_name: str
extracted_answer_column_name: str
question_type_column_name: str
model_name: str
mode: str

@abstractmethod
def _parse_answer_function(self, answer_text, question_type):
return extract_answer_from_text_map(answer_text, question_type, self.model_name)
return extract_answer_from_text_grid(answer_text, question_type)


@dataclass
class ExtractAnswerMaze(ExtractAnswer):
"""This class is an answer extractor for the MAZE benchmark."""
class ExtractAnswerSpatialMapAndMaze(DFTransformBase):
"""This class is an answer extractor for the SPATIAL_MAP and MAZE benchmark."""

answer_column_name: str
extracted_answer_column_name: str
question_type_column_name: str
extracted_options_column_name: str

@abstractmethod
def _parse_answer_function(self, answer_text, question_type):
return extract_answer_from_text_maze(answer_text, question_type)
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
df[self.extracted_answer_column_name] = df.apply(
lambda x: extract_answer_from_text_map_and_maze(x[self.answer_column_name], x[self.extracted_options_column_name]), axis=1
)
return df
19 changes: 12 additions & 7 deletions eureka_ml_insights/user_configs/vision_language/maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
ColumnRename,
DataLoader,
DataReader,
ExtractAnswerMaze,
ExtractQuestionOptions,
ExtractAnswerSpatialMapAndMaze,
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator
from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator

from eureka_ml_insights.configs import (
AggregatorConfig,
Expand Down Expand Up @@ -81,23 +82,27 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
"format": ".jsonl",
"transform": SequenceTransform(
[
ExtractQuestionOptions(
prompt_column_name="prompt",
extracted_options_column_name="target_options_answers",
),
ColumnRename(name_mapping={"model_output": "model_output_raw"}),
ExtractAnswerMaze(
ExtractAnswerSpatialMapAndMaze(
answer_column_name="model_output_raw",
extracted_answer_column_name="model_output",
question_type_column_name="question_type",
extracted_options_column_name="target_options_answers",
),
],
),
},
),
metric_config=MetricConfig(CaseInsensitiveMatch),
metric_config=MetricConfig(SubstringExistsMatch),
aggregator_configs=[
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveMatch_result"], "normalize": True}),
AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}),
AggregatorConfig(
CountAggregator,
{
"column_names": ["CaseInsensitiveMatch_result"],
"column_names": ["SubstringExistsMatch_result"],
"group_by": "task",
"normalize": True,
},
Expand Down
Loading
Loading