feat: add the chess move validity dataset
DriesSmit committed May 7, 2024
1 parent 3a1c07e commit 2e2aab0
Showing 7 changed files with 121 additions and 1 deletion.
5 changes: 5 additions & 0 deletions debatellm/eval/eval_system.py
@@ -485,6 +485,11 @@ def eval_system(
            file_handler,
            dataset_settings,
        )
    elif dataset == "chess":
        questions, format_question = load_datasets.chess_questions(
            file_handler,
            dataset_settings,
        )
    else:
        raise ValueError(f"Dataset {dataset} not supported.")

77 changes: 77 additions & 0 deletions debatellm/eval/load_datasets.py
@@ -528,3 +528,80 @@ def format_question(d: dict) -> str:
        return question

    return questions, format_question


# flake8: noqa: C901
def chess_questions(
    file_handler: FileHandler,
    dataset_settings: Dict,
    seed: Optional[int] = 42,
) -> Tuple[list, Callable]:
    """Load the chess move validity dataset and question format function.
    Args:
        file_handler: file handler for s3 or local files.
        dataset_settings : dictionary containing the dataset
            settings, e.g. the exam_type to evaluate on.
    Returns:
        questions : list of questions.
        format_question : function to format question correctly.
    """

    # Fix the random seed to ensure fair comparison and reproducibility
    random.seed(seed)  # Set the seed to a fixed value

    exam_type = dataset_settings["exam_type"]
    assert exam_type == "short"

    data_path = "task.json"
    json_obj = file_handler.read_json_file(path=data_path)

    all_targets = [" ".join(example["target"]) for example in json_obj["examples"]]

    questions = []
    for i, q in enumerate(json_obj["examples"]):
        q_text = (
            "For the following (in-progress) chess game, please select the option "
            "that shows all the legal destination squares that complete the notation: "
            + q["input"]
        )
        correct_answer = " ".join(q["target"])

        # Generate distractors from other example targets
        distractors = [ans for idx, ans in enumerate(all_targets) if idx != i]

        unique_distractors_found = False
        # Loop until the three selected distractors all differ from the correct answer
        while not unique_distractors_found:
            random.shuffle(distractors)  # Shuffle distractors to randomize selection
            selected_distractors = distractors[:3]  # Select the first three shuffled distractors
            unique_distractors_found = all(
                distractor != correct_answer for distractor in selected_distractors
            )

        options = [correct_answer] + selected_distractors
        random.shuffle(options)  # Shuffle options to prevent predictable positions

        # Identify the correct solution after shuffling
        solution_index = options.index(correct_answer)
        solution = chr(65 + solution_index)  # Map index to A, B, C, D for multiple choice

        question = {
            "no": i,
            "question": q_text,
            "options": {chr(65 + j): opt for j, opt in enumerate(options)},
            "solution": solution,
            "subcategory": "Chess Moves",
            "category": "Game Analysis",
        }

        questions.append(question)

    def format_question(d: dict) -> str:
        question = ""
        options = d["options"]
        for key in ["category", "subcategory", "question"]:
            if d[key]:
                question += f"\n{key}: {d[key]}"
        for k, v in options.items():
            question += f"\n{k}: {v}"
        return question

    return questions, format_question
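
A rough usage sketch for the new loader (illustrative only, not part of this commit), assuming task.json has already been fetched by scripts/eval_datasets/download_chess.sh and that file_handler is an already-constructed handler rooted at the chess dataset folder:

    from debatellm.eval.load_datasets import chess_questions

    # file_handler: a local or S3 file handler exposing read_json_file(path);
    # how it is constructed is project-specific and omitted here.
    questions, format_question = chess_questions(
        file_handler,
        dataset_settings={"exam_type": "short"},  # "short" is the only supported exam_type
    )

    print(len(questions))                 # one multiple-choice question per BIG-bench example
    print(format_question(questions[0]))  # category, subcategory, question text, then options A-D
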
6 changes: 6 additions & 0 deletions debatellm/utils/s3_io.py
@@ -225,6 +225,12 @@ def read_json_lines(self, path: str) -> List[Dict]:
data.append(json.loads(line))
return data

    def read_json_file(self, path: str) -> Dict:
        """Read a single JSON file and return its parsed contents."""
        full_path = os.path.join(self.bucket_path, path)
        with open(full_path, "r") as file:
            data = json.load(file)
        return data

def save_json(self, path: str, data: Dict) -> None:
"""Wrapper around the python json save method
2 changes: 1 addition & 1 deletion experiments/conf/config.yaml
@@ -1,7 +1,7 @@
defaults:
- system: single_agent # [options: single_agent, debateqa, google_mad, ensemble_refinement,
# tsinghua_mad, spp_synergy, chateval, medprompt]
- dataset: medqa # [options: usmle, medmcqa, mmlu, pubmedqa, medqa, ciar, cosmosqa, gpqa]
- dataset: chess # [options: usmle, medmcqa, mmlu, pubmedqa, medqa, ciar, cosmosqa, gpqa, chess]
- _self_

max_eval_count: 1
4 changes: 4 additions & 0 deletions experiments/conf/dataset/chess.yaml
@@ -0,0 +1,4 @@
eval_dataset: "chess"
path_to_exams: "data/datasets/${.eval_dataset}/"
dataset_settings:
  exam_type: "short" # options: [short]
3 changes: 3 additions & 0 deletions scripts/download_eval_datasets.sh
@@ -21,3 +21,6 @@ scripts/eval_datasets/download_cosmosqa.sh

# download gpqa dataset
scripts/eval_datasets/download_gpqa.sh

# download chess dataset
scripts/eval_datasets/download_chess.sh
25 changes: 25 additions & 0 deletions scripts/eval_datasets/download_chess.sh
@@ -0,0 +1,25 @@
#!/bin/bash

# Check for AICHOR_INPUT_PATH and set FOLDER accordingly
if [[ -z "${AICHOR_INPUT_PATH}" ]]; then
    FOLDER=data/datasets/chess/
else
    FOLDER="${AICHOR_INPUT_PATH}data/datasets/chess/"
fi

# Check if the dataset already exists
if [ -e "${FOLDER}task.json" ]; then
    echo "CHESS dataset already exists in $FOLDER"
else
    echo "Downloading CHESS dataset to $FOLDER"

    # Ensure the target directory exists
    if [ ! -d "${FOLDER}" ]; then
        echo "Creating folder ${FOLDER}"
        mkdir -p "${FOLDER}"
    fi

    # Download the specific part of the dataset
    wget https://raw.githubusercontent.com/google/BIG-bench/main/bigbench/benchmark_tasks/chess_state_tracking/real_short/task.json -P "${FOLDER}"
    echo "Dataset downloaded to $FOLDER"
fi
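
For reference, chess_questions in load_datasets.py only touches a few fields of the downloaded file; roughly, it assumes task.json has the shape sketched below (placeholder values inferred from the loader code, not copied from the actual BIG-bench file):

    # Shape assumed by chess_questions; the values here are placeholders, not real data.
    task_json = {
        "examples": [
            {
                "input": "<in-progress game notation with an incomplete final move>",
                "target": ["<legal destination square>", "<another legal destination square>"],
            },
        ],
    }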
