Skip to content

Commit

Permalink
Merge pull request #36 from victordibia/dev
Browse files Browse the repository at this point in the history
Update retriever to use config.yaml
  • Loading branch information
victordibia authored Jul 23, 2020
2 parents 0d48a2d + 71094c8 commit 31c4a93
Show file tree
Hide file tree
Showing 20 changed files with 416 additions and 207 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ jobs:
- uses: actions/checkout@v2

# Example of using a custom build-command.

- uses: ammaraskar/sphinx-action@master
with:
docs-folder: "docs/"

# Create an artifact of the html output.
- uses: actions/upload-artifact@v1
with:
name: DocumentationHTML
path: docs/_build/html/

# Publish built docs to gh-pages branch.
# ===============================
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ neuralqa/data
docs/images/extra
config.yaml
.vscode
neuralqa/server/test_query.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## NeuralQA: A Usable Library for (Extractive) Question Answering on Large Datasets with BERT

[![License: MIT](https://img.shields.io/github/license/victordibia/neuralqa?style=flat-square)](https://opensource.org/licenses/MIT)
[![License: MIT](https://img.shields.io/github/license/victordibia/neuralqa)](https://opensource.org/licenses/MIT)
![docs](https://github.com/victordibia/neuralqa/workflows/docs/badge.svg?style=flat-square)

> Still in **alpha**, lots of changes anticipated.
Expand Down Expand Up @@ -35,6 +35,8 @@ NeuralQA is composed of several high-level modules:

- **Reader**: For each retrieved passage, a BERT-based model predicts a span that contains the answer to the question. In practice, retrieved passages may be lengthy and BERT-based models can process a maximum of 512 tokens at a time. NeuralQA handles this in two ways. First, lengthy passages are chunked into smaller sections with a configurable stride. Second, NeuralQA offers the option of extracting a subset of relevant snippets (RelSnip), which a BERT reader can then scan to find answers. Relevant snippets are portions of the retrieved document that contain exact match results for the search query.

- **Expander**: Methods for generating additional queries to improve recall.

- **User Interface**: NeuralQA provides a visual user interface for performing queries (manual queries where question and context are provided as well as queries over a search index), viewing results and also sensemaking of results (reranking of passages based on answer scores, highlighting keyword match, model explanations).

## Configuration
Expand Down Expand Up @@ -78,7 +80,6 @@ reader:
## Documentation
Documentation for NeuralQA is being developed here - [https://victordibia.github.io/neuralqa/](https://victordibia.github.io/neuralqa/).
## Citation
Expand Down
8 changes: 5 additions & 3 deletions neuralqa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from neuralqa.utils import cli_args
from neuralqa.utils import import_case_data, ConfigParser
import os
from neuralqa.retriever import RetrieverPool


@click.group()
Expand All @@ -17,9 +18,10 @@ def cli():
@cli_args.WORKERS
@cli_args.CONFIG_PATH
def test(host, port, workers, config_path):
config = ConfigParser(config_path)
print(config.config["reader"]["models"])
# import_case_data()
# config = ConfigParser(config_path)
# rp = RetrieverPool(config.config["retriever"])
# print((rp.retriever_pool))
import_case_data()


@cli.command()
Expand Down
29 changes: 24 additions & 5 deletions neuralqa/config_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ ui:
samples: True # show/hide sample question answer pairs
passages: True # show/hide passages which are retrieved
explanations: True # show/hide explanations button
allanswers: True # show all answers or just the best answer (based on probability score)
allanswers: False # show all answers or just the best answer (based on probability score)
expander: False
options:
stride:
Expand Down Expand Up @@ -44,16 +44,14 @@ ui:
title: Highlight Span
selected: 250
options:
- name: 150
value: 150
- name: 250
value: 250
- name: 350
value: 350
- name: 450
value: 450
- name: 650
value: 650
- name: 850
value: 850
samples:

retriever:
Expand All @@ -63,20 +61,35 @@ retriever:
options:
- name: Manual
value: manual
type: manual

- name: Case Law
value: cases
host: localhost
port: 9200
username: None
password: None
type: elasticsearch
fields:
body_field: casebody.data.opinions.text
- name: Medical
value: medical
host: localhost
port: 9200
username: None
password: None
type: elasticsearch
fields:
body_field: context
- name: Supreme Court
value: supremecourt
host: localhost
port: 9200
username: None
password: None
type: elasticsearch
fields:
body_field: casebody
readtopn: 0

relsnip:
Expand All @@ -99,6 +112,12 @@ reader:
- name: DistilBERT SQUAD2
value: twmkn9/distilbert-base-uncased-squad2
type: distilbert
- name: BERT SQUAD2
value: deepset/bert-base-cased-squad2
type: bert
# - name: Medical BERT SQUAD2 # example of a model with local files
# value: /Users/victordibia/Downloads/meddistilbert
# type: bert

expander:
title: Expander
Expand Down
12 changes: 11 additions & 1 deletion neuralqa/reader/bertreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import time
import logging

logger = logging.getLogger(__name__)


class BERTReader(Reader):
def __init__(self, model_name, model_path, model_type="bert", **kwargs):
Expand Down Expand Up @@ -46,7 +48,7 @@ def get_chunk_answer_span(self, inputs):
"probability": str(answer_end_softmax_probability + answer_start_softmax_probability)
}

def token_chunker(self, question, context, max_chunk_size=512, stride=2):
def token_chunker(self, question, context, max_chunk_size=512, stride=2, max_num_chunks=5):
# we tokenize question and context once.
# if question + context > max chunksize, we break it down into multiple chunks of question +
# subsets of context with some stride overlap
Expand All @@ -59,7 +61,13 @@ def token_chunker(self, question, context, max_chunk_size=512, stride=2):
chunk_size = max_chunk_size - len(question_tokens) - 1
# -1 for the 102 end token we append later
current_pos = 0
chunk_count = 0
while current_pos < len(context_tokens) and current_pos >= 0:

# we want to cap the number of chunks we create
if max_num_chunks and chunk_count >= max_num_chunks:
break

end_point = current_pos + \
chunk_size if (current_pos + chunk_size) < len(context_tokens) - \
1 else len(context_tokens) - 1
Expand Down Expand Up @@ -87,6 +95,8 @@ def token_chunker(self, question, context, max_chunk_size=512, stride=2):
"token_type_ids": token_type_ids
})
current_pos = current_pos + chunk_size - stride + 1
chunk_count +=1

return chunk_holder

def answer_question(self, question, context, max_chunk_size=512, stride=70):
Expand Down
7 changes: 5 additions & 2 deletions neuralqa/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
logging.getLogger("transformers.modeling_tf_utils").setLevel(logging.ERROR)


logger = logging.getLogger(__name__)


class Reader:
def __init__(self, model_name, model_path, model_type, **kwargs):
self.load_model(model_name, model_path, model_type)

def load_model(self, model_name, model_path, model_type):
logging.info(" >> loading HF model " +
model_name + " from " + model_path)
logger.info(">> loading HF model " +
model_name + " from " + model_path)
self.type = model_type
self.name = model_name
self.tokenizer = AutoTokenizer.from_pretrained(
Expand Down
4 changes: 3 additions & 1 deletion neuralqa/reader/readerpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from neuralqa.reader import BERTReader
import logging

logger = logging.getLogger(__name__)


class ReaderPool():
def __init__(self, models):
Expand All @@ -26,6 +28,6 @@ def selected_model(self, selected_model):
self._selected_model = selected_model
else:
default_model = next(iter(self.reader_pool))
logging.info(
logger.info(
">> Model you are attempting to use %s does not exist in model pool. Using the following default model instead %s ", selected_model, default_model)
self._selected_model = default_model
1 change: 1 addition & 0 deletions neuralqa/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .retriever import *
from .elasticsearchretriever import *
from .retrieverpool import *
93 changes: 65 additions & 28 deletions neuralqa/retriever/elasticsearchretriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,83 @@


logging.getLogger("elasticsearch").setLevel(logging.CRITICAL)
logger = logging.getLogger(__name__)


class ElasticSearchRetriever(Retriever):
def __init__(self, index_type="elasticsearch", host="localhost", port=9200):
def __init__(self, index_type="elasticsearch", host="localhost", port=9200, **kwargs):
Retriever.__init__(self, index_type)

self.es = Elasticsearch([{'host': host, 'port': port}])
self.username = ""
self.password = ""
self.body_field = ""
self.host = host
self.port = port

allowed_keys = list(self.__dict__.keys())
self.__dict__.update((k, v)
for k, v in kwargs.items() if k in allowed_keys)
self.es = Elasticsearch([{'host': self.host, 'port': self.port}])
self.isAvailable = self.es.ping()

def run_query(self, index_name, search_query):
"""Makes a query to the elastic search server with the given search_query parameters.
Also returns opinion_excerpt script field, which is a substring of the first opinion in the case
rejected_keys = set(kwargs.keys()) - set(allowed_keys)

if rejected_keys:
raise ValueError(
"Invalid arguments in ElasticSearchRetriever constructor:{}".format(rejected_keys))

def run_query(self, index_name, search_query, max_documents=5, highlight_span=100, relsnip=True, num_fragments=5, highlight_tags=True):

tags = {"pre_tags": [""], "post_tags": [
""]} if not highlight_tags else {}
highlight_params = {
"fragment_size": highlight_span,
"fields": {
self.body_field: tags
},
"number_of_fragments": num_fragments
}

Arguments:
search_query {[dictionary]} -- [contains a dictionary that corresponds to an elastic search query on the ]
search_query = {
"_source":{"includes": [self.body_field]},
"query": {
"multi_match": {
"query": search_query,
"fields": [self.body_field]
}
},
"size": max_documents
}

Returns:
[dictionary] -- [dictionary of results from elastic search.]
"""
query_result = None
status = True
results = {}

if (relsnip):
# search_query["_source"] = {"includes": [""]}
search_query["highlight"] = highlight_params
# else:
# search_query["_source"] = {"includes": [self.body_field]}

# return error as result on error.
# Calling function should check status before parsing result
try:
query_result = {
"status": True,
"result": self.es.search(index=index_name, body=search_query)
}
except ConnectionRefusedError as e:
query_result = {
"status": False,
"result": str(e)
}
except Exception as e:
query_result = {
"status": False,
"result": str(e)
}
return query_result
query_result = self.es.search(
index=index_name, body=search_query)
# RelSnip: for each document, we concatenate all
# fragments in each document and return as the document.
highlights = [" ".join(hit["highlight"][self.body_field])
for hit in query_result["hits"]["hits"] if "highlight" in hit]
docs = [((hit["_source"][self.body_field]))
for hit in query_result["hits"]["hits"] if hit["_source"]]
took = query_result["took"]
results = {"took": took, "highlights": highlights, "docs": docs}

except (ConnectionRefusedError, Exception) as e:
status = False
results["errormsg"] = str(e)

results["status"] = status
return results


def test_connection(self):
try:
self.es.cluster.health()
Expand Down
40 changes: 40 additions & 0 deletions neuralqa/retriever/retrieverpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

from neuralqa.retriever import ElasticSearchRetriever
import logging

logger = logging.getLogger(__name__)


class RetrieverPool():
    """Holds every retriever declared in config.yaml and tracks the active one.

    Built from the ``retriever`` section of the config: each entry under
    ``options`` describes one retriever. Only the ``elasticsearch`` type is
    instantiated; ``solr`` is acknowledged but not yet supported.
    """

    def __init__(self, retrievers):
        self.retriever_pool = {}
        for option in retrievers["options"]:
            key = option["value"]
            # each retriever "value" in config.yaml must be unique
            if key in self.retriever_pool:
                raise ValueError(
                    "Duplicate retriever value : {} ".format(key))

            kind = option["type"]
            if kind == "elasticsearch":
                self.retriever_pool[key] = ElasticSearchRetriever(
                    host=option["host"],
                    port=option["port"],
                    body_field=option["fields"]["body_field"])
            if kind == "solr":
                logger.info("We do not yet support Solr retrievers")
        self.selected_retriever = retrievers["selected"]

    @property
    def retriever(self):
        """The currently selected retriever instance."""
        return self.retriever_pool[self.selected_retriever]

    @property
    def selected_retriever(self):
        """Key (config ``value``) of the retriever currently in use."""
        return self._selected_retriever

    @selected_retriever.setter
    def selected_retriever(self, selected_retriever):
        # unknown key: fall back to the first configured retriever,
        # mirroring the behaviour of ReaderPool.selected_model
        if selected_retriever not in self.retriever_pool:
            default_retriever = next(iter(self.retriever_pool))
            logger.info(
                ">> Retriever you are attempting to use (%s) does not exist in retriever pool. Using the following default retriever instead %s ", selected_retriever, default_retriever)
            selected_retriever = default_retriever
        self._selected_retriever = selected_retriever
Loading

0 comments on commit 31c4a93

Please sign in to comment.