Skip to content

Commit

Permalink
v0.0 for general retriever specification in yaml #26
Browse files: browse the repository at this point in the history
  • Loading branch information
victordibia committed Jul 22, 2020
1 parent a3b131e commit 94ffb4d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 69 deletions.
134 changes: 66 additions & 68 deletions neuralqa/server/routehandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,24 @@ async def get_answers(params: Answer):
# switch to the selected model
self._reader_pool.selected_model = params.reader

included_fields = ["name"]
search_query = {
"_source": included_fields,
"query": {
"multi_match": {
"query": params.question,
"fields": ["casebody.data.opinions.text", "name"]
}
},
"highlight": {
"fragment_size": params.highlight_span,
"fields": {
"casebody.data.opinions.text": {"pre_tags": [""], "post_tags": [""]},
"name": {}
}
},
"size": params.max_documents
}
# included_fields = ["name"]
# search_query = {
# "_source": included_fields,
# "query": {
# "multi_match": {
# "query": params.question,
# "fields": ["casebody.data.opinions.text", "name"]
# }
# },
# "highlight": {
# "fragment_size": params.highlight_span,
# "fields": {
# "casebody.data.opinions.text": {"pre_tags": [""], "post_tags": [""]},
# "name": {}
# }
# },
# "size": params.max_documents
# }

answer_holder = []
response = {}
Expand All @@ -62,25 +62,47 @@ async def get_answers(params: Answer):
answer_holder.append(answer)
# answer question based on retrieved passages from elastic search
else:
query_result = self._retriever.run_query(
params.retriever, search_query)
if query_result["status"]:
query_result = query_result["result"]
for i, hit in enumerate(query_result["hits"]["hits"]):
if ("casebody.data.opinions.text" in hit["highlight"]):
# context passage is a concatenation of highlights
if (params.relsnip):
context = " .. ".join(
hit["highlight"]["casebody.data.opinions.text"])
else:
context = str(hit["_source"])
# print(context)

answers = self._reader_pool.model.answer_question(
params.question, context, stride=params.tokenstride)
for answer in answers:
answer["index"] = i
answer_holder.append(answer)
body_field = "casebody"
secondary_fields = ["author"]
num_fragments = 5
query_results = self._retriever.run_query(params.retriever, params.question, body_field, secondary_fields,
max_documents=params.max_documents, highlight_span=params.highlight_span,
relsnip=params.relsnip, num_fragments=num_fragments, highlight_tags=False)

if (not query_results["status"]):
return query_results
# query_result = self._retriever.run_query(
# params.retriever, search_query)
docs = query_results["highlights"] if params.relsnip else query_results["docs"]

for i, doc in enumerate(docs):
doc = doc.replace("\n", " ")
print(doc)
answers = self._reader_pool.model.answer_question(
params.question, doc, stride=params.tokenstride)
for answer in answers:
print(answer)
answer["index"] = i
answer_holder.append(answer)

print(answer_holder)
# if query_result["status"]:
# query_result = query_result["result"]
# for i, hit in enumerate(query_result["hits"]["hits"]):
# if ("casebody.data.opinions.text" in hit["highlight"]):
# # context passage is a concatenation of highlights
# if (params.relsnip):
# context = " .. ".join(
# hit["highlight"]["casebody.data.opinions.text"])
# else:
# context = str(hit["_source"])
# # print(context)

# answers = self._reader_pool.model.answer_question(
# params.question, context, stride=params.tokenstride)
# for answer in answers:
# answer["index"] = i
# answer_holder.append(answer)

# sort answers by probability
answer_holder = sorted(
Expand All @@ -97,37 +119,13 @@ async def get_documents(params: Document):
dictionary -- contains details on elastic search results.
"""

included_fields = ["name"]
opinion_excerpt_length = 500

# return only included fields + script_field,
# limit response to top max_documents matches return highlights
search_query = {
"_source": included_fields,
"query": {
"multi_match": {
"query": params.question,
"fields": ["casebody.data.opinions.text", "name"]
}
},
"script_fields": {
"opinion_excerpt": {
"script": "(params['_source']['casebody']['data']['opinions'][0]['text']).substring(0," + str(opinion_excerpt_length) + ")"
}
},
"highlight": {
"fragment_size": params.highlight_span,
"fields": {
"casebody.data.opinions.text": {},
"name": {}
}
},
"size": params.max_documents
}

query_result = self._retriever.run_query(
params.retriever, search_query)
return query_result
body_field = "casebody"
secondary_fields = ["author"]
num_fragments = 5
query_results = self._retriever.run_query(params.retriever, params.question, body_field, secondary_fields,
max_documents=params.max_documents, highlight_span=params.highlight_span, relsnip=True, num_fragments=num_fragments)

return query_results

@router.post("/explain")
async def get_explanation(params: Explanation):
Expand Down
2 changes: 1 addition & 1 deletion neuralqa/server/routemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Answer(BaseModel):
question: str = "what is a fourth amendment right violation? "
highlight_span: int = 250
tokenstride: int = 50
context: str = "The fourth amendment kind of protects the rights of citizens .. such that they dont get searched"
context: Optional[str] = "The fourth amendment kind of protects the rights of citizens .. such that they dont get searched"
reader: str = None
relsnip: bool = True
retriever: Optional[str] = "manual"
Expand Down

0 comments on commit 94ffb4d

Please sign in to comment.