Commit 07d3adb
fixed verbose to be integer, and some resulting cleanups
Matan Ninio committed Mar 19, 2024
1 parent 827c9ee commit 07d3adb
Showing 5 changed files with 23 additions and 12 deletions.
@@ -3,7 +3,7 @@ paths:
   AA_tokenizer_json: "t5_tokenizer_AA_special.json"
   SMILES_tokenizer_json: "bpe_tokenizer_trained_on_chembl_zinc_with_aug_4272372_samples_balanced_1_1.json"
   cell_attributes_tokenizer_json: "cell_attributes_tokenizer.json"
-  modular_tokenizers_out_path: "${paths.tokenizers_path}/modular_AA_SMILES/"
+  modular_tokenizers_out_path: "${paths.tokenizers_path}/bmfm_modular_tokenizer/"
   original_tokenizers_path: "${paths.tokenizers_path}"

@@ -34,3 +34,9 @@ data:
     modular_json_path: "${paths.modular_tokenizers_out_path}/${paths.SMILES_tokenizer_json}"
     start_delimiter: "<start_SMILES>" # String to start the sequence. If None or undefined, <start_${type key}> will be used as the name
     end_delimiter: "<end_SMILES>" # String to end the sequence. If None or undefined, <end_${type key}> will be used as the name
+  - name: CELL_ATTRIBUTES # if None or undefined, type key will be used as the name
+    tokenizer_id: 2 # unique identifier of the tokenizer
+    json_path: "${paths.original_tokenizers_path}${paths.cell_attributes_tokenizer_json}"
+    modular_json_path: "${paths.modular_tokenizers_out_path}${paths.cell_attributes_tokenizer_json}"
+    start_delimiter: "<start_CELL_ATTRIBUTES>" # String to start the sequence. If None or undefined, <start_${type key}> will be used as the name
+    end_delimiter: "<end_CELL_ATTRIBUTES>" # String to end the sequence. If None or undefined, <end_${type key}> will be used as the name
fusedrug/data/tokenizer/modulartokenizer/create_multi_tokenizer.py (17 changes: 11 additions & 6 deletions)
@@ -29,8 +29,13 @@ def test_tokenizer(
     mode: Optional[str] = "",
     input_strings: Optional[List] = None,
     on_unknown: Optional[str] = "warn",
-    verbose: bool = False,
+    verbose: int = 0,
+    cfg_raw: Dict = None,
 ) -> None:
+    # for compatibility
+    if overall_max_length is None and cfg_raw is not None:
+        overall_max_length = cfg_raw["data"]["tokenizer"]["overall_max_len"]
+
     if input_strings is None:
         input_strings = [
             TypedInput("AA", "<BINDING>ACDEFGHIJKLMNPQRSUVACDEF", 10),

@@ -47,7 +52,7 @@
         max_len=overall_max_length,
         return_overflow_info=True,
         on_unknown=on_unknown,
-        verbose=1 if verbose else 0,
+        verbose=verbose,
     )
     if verbose:
         print(f"encoded tokens: {enc.tokens}, overflow=[{overflow_msg}]")

@@ -56,7 +61,7 @@
         typed_input_list=input_strings,
         max_len=50,
         on_unknown=on_unknown,
-        verbose=1 if verbose else 0,
+        verbose=verbose,
     )
     assert (
         len(enc_pad.ids) == 50

@@ -67,7 +72,7 @@
     enc_pad = t_inst.encode_list(
         typed_input_list=input_strings,
         on_unknown=on_unknown,
-        verbose=1 if verbose else 0,
+        verbose=verbose,
     )
     assert (
         len(enc_pad.ids) == 70

@@ -78,7 +83,7 @@
     enc_pad = t_inst.encode_list(
         typed_input_list=input_strings,
         on_unknown=on_unknown,
-        verbose=1 if verbose else 0,
+        verbose=verbose,
     )
     assert (
         len(enc_pad.ids) == 70

@@ -89,7 +94,7 @@
         typed_input_list=input_strings,
         on_unknown=on_unknown,
         max_len=15,
-        verbose=1 if verbose else 0,
+        verbose=verbose,
     )
     assert (
         len(enc_trunc.ids) == 15
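The caller-side effect of the signature change, as a short usage sketch (t_inst, input_strings, and overall_max_length as in the test function above): the bool-to-int bridge "1 if verbose else 0" disappears, and the integer level is forwarded as-is.

# encode_list now receives the integer verbosity level directly:
enc, overflow_msg = t_inst.encode_list(
    typed_input_list=input_strings,
    max_len=overall_max_length,
    return_overflow_info=True,
    on_unknown="warn",
    verbose=2,  # 0-3; see the encode() docstring in modular_tokenizer.py below
)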
fusedrug/data/tokenizer/modulartokenizer/modular_tokenizer.py (4 changes: 2 additions & 2 deletions)
@@ -998,7 +998,7 @@ def encode_list(
         pad_type_id: Optional[int] = None,
         return_overflow_info: Optional[bool] = False,
         on_unknown: Optional[str] = "warn",
-        verbose: Optional[int] = 1,
+        verbose: int = 1,
     ) -> Union[Encoding, Tuple[Encoding, str]]:
         """_summary_

@@ -1201,7 +1201,7 @@ def encode(
             pad_type_id (Optional[int], optional): _description_. Defaults to 0.
             return_overflow_info (Optional[bool], optional): _description_. If True, return an additional string with overflow information. Defaults to False.
             on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'
-            verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
+            verbose (int, optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
                 with full data. Defaults to 1.
         Returns:
             Encoding: _description_
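A hypothetical helper that illustrates the four documented levels; it is not part of the repository, just a sketch of the semantics the docstring describes:

import warnings

def notify(event: str, partial: str, full: str, verbose: int) -> None:
    if verbose <= 0:
        return  # 0: no notification
    if verbose == 1:
        warnings.warn(event)  # 1: warning notification
    elif verbose == 2:
        warnings.warn(f"{event}: {partial}")  # 2: warning with partial data
    else:
        warnings.warn(f"{event}: {full}")  # 3: warning with full data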
@@ -12,7 +12,7 @@


 def compare_modular_tokenizers(
-    tokenizer1_name: str, tokenizer2_name: str, verbose: bool = False
+    tokenizer1_name: str, tokenizer2_name: str, verbose: int = 0
 ) -> None:

     pertrained_tokenizers_path = Path(__file__).parents[1] / "pretrained_tokenizers"
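Usage sketch for the updated signature. The two directory names are taken from the config change in this commit and are only illustrative arguments:

# compare two pretrained modular tokenizers, with warning-level verbosity:
compare_modular_tokenizers("modular_AA_SMILES", "bmfm_modular_tokenizer", verbose=1)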
@@ -42,7 +42,7 @@ def _setup_test_env(self) -> None:


 class TestModularTokenizer(unittest.TestCase):
-    def setUp(self, config_holder: ConfigHolder = None, verbose: bool = False) -> None:
+    def setUp(self, config_holder: ConfigHolder = None, verbose: int = 0) -> None:
         if config_holder is None:
             config_holder = ConfigHolder()
         cfg = config_holder.get_config()

@@ -108,7 +108,7 @@ def test_tokenizer_with_exception(self) -> None:

 @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME, version_base=None)
 def main(cfg: DictConfig) -> None:
-    verbose = cfg.get("verbose", False)
+    verbose = cfg.get("verbose", 0)
     if verbose:
         print(str(cfg))
     config_holder = ConfigHolder(cfg)
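Since main() only truth-tests the value from cfg.get("verbose", 0), any level of 1 or higher enables the config printout while 0 stays silent. A self-contained sketch of that behavior using OmegaConf directly (the real script receives cfg through Hydra):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"verbose": 2})
verbose = cfg.get("verbose", 0)
if verbose:  # any integer level >= 1 is truthy; 0 suppresses the printout
    print(OmegaConf.to_yaml(cfg))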
