From 94ffb4d3e9305c569cb24cb2f0133561d6d44a29 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Tue, 21 Jul 2020 22:43:42 -0700 Subject: [PATCH] v0.0 for general retriever specification in yaml #26 --- neuralqa/server/routehandlers.py | 134 +++++++++++++++---------------- neuralqa/server/routemodels.py | 2 +- 2 files changed, 67 insertions(+), 69 deletions(-) diff --git a/neuralqa/server/routehandlers.py b/neuralqa/server/routehandlers.py index bce7c00..3014365 100644 --- a/neuralqa/server/routehandlers.py +++ b/neuralqa/server/routehandlers.py @@ -30,24 +30,24 @@ async def get_answers(params: Answer): # switch to the selected model self._reader_pool.selected_model = params.reader - included_fields = ["name"] - search_query = { - "_source": included_fields, - "query": { - "multi_match": { - "query": params.question, - "fields": ["casebody.data.opinions.text", "name"] - } - }, - "highlight": { - "fragment_size": params.highlight_span, - "fields": { - "casebody.data.opinions.text": {"pre_tags": [""], "post_tags": [""]}, - "name": {} - } - }, - "size": params.max_documents - } + # included_fields = ["name"] + # search_query = { + # "_source": included_fields, + # "query": { + # "multi_match": { + # "query": params.question, + # "fields": ["casebody.data.opinions.text", "name"] + # } + # }, + # "highlight": { + # "fragment_size": params.highlight_span, + # "fields": { + # "casebody.data.opinions.text": {"pre_tags": [""], "post_tags": [""]}, + # "name": {} + # } + # }, + # "size": params.max_documents + # } answer_holder = [] response = {} @@ -62,25 +62,47 @@ async def get_answers(params: Answer): answer_holder.append(answer) # answer question based on retrieved passages from elastic search else: - query_result = self._retriever.run_query( - params.retriever, search_query) - if query_result["status"]: - query_result = query_result["result"] - for i, hit in enumerate(query_result["hits"]["hits"]): - if ("casebody.data.opinions.text" in hit["highlight"]): - # context passage is a concatenation of highlights - if (params.relsnip): - context = " .. ".join( - hit["highlight"]["casebody.data.opinions.text"]) - else: - context = str(hit["_source"]) - # print(context) - - answers = self._reader_pool.model.answer_question( - params.question, context, stride=params.tokenstride) - for answer in answers: - answer["index"] = i - answer_holder.append(answer) + body_field = "casebody" + secondary_fields = ["author"] + num_fragments = 5 + query_results = self._retriever.run_query(params.retriever, params.question, body_field, secondary_fields, + max_documents=params.max_documents, highlight_span=params.highlight_span, + relsnip=params.relsnip, num_fragments=num_fragments, highlight_tags=False) + + if (not query_results["status"]): + return query_results + # query_result = self._retriever.run_query( + # params.retriever, search_query) + docs = query_results["highlights"] if params.relsnip else query_results["docs"] + + for i, doc in enumerate(docs): + doc = doc.replace("\n", " ") + print(doc) + answers = self._reader_pool.model.answer_question( + params.question, doc, stride=params.tokenstride) + for answer in answers: + print(answer) + answer["index"] = i + answer_holder.append(answer) + + print(answer_holder) + # if query_result["status"]: + # query_result = query_result["result"] + # for i, hit in enumerate(query_result["hits"]["hits"]): + # if ("casebody.data.opinions.text" in hit["highlight"]): + # # context passage is a concatenation of highlights + # if (params.relsnip): + # context = " .. ".join( + # hit["highlight"]["casebody.data.opinions.text"]) + # else: + # context = str(hit["_source"]) + # # print(context) + + # answers = self._reader_pool.model.answer_question( + # params.question, context, stride=params.tokenstride) + # for answer in answers: + # answer["index"] = i + # answer_holder.append(answer) # sort answers by probability answer_holder = sorted( @@ -97,37 +119,13 @@ async def get_documents(params: Document): dictionary -- contains details on elastic search results. """ - included_fields = ["name"] - opinion_excerpt_length = 500 - - # return only included fields + script_field, - # limit response to top max_documents matches return highlights - search_query = { - "_source": included_fields, - "query": { - "multi_match": { - "query": params.question, - "fields": ["casebody.data.opinions.text", "name"] - } - }, - "script_fields": { - "opinion_excerpt": { - "script": "(params['_source']['casebody']['data']['opinions'][0]['text']).substring(0," + str(opinion_excerpt_length) + ")" - } - }, - "highlight": { - "fragment_size": params.highlight_span, - "fields": { - "casebody.data.opinions.text": {}, - "name": {} - } - }, - "size": params.max_documents - } - - query_result = self._retriever.run_query( - params.retriever, search_query) - return query_result + body_field = "casebody" + secondary_fields = ["author"] + num_fragments = 5 + query_results = self._retriever.run_query(params.retriever, params.question, body_field, secondary_fields, + max_documents=params.max_documents, highlight_span=params.highlight_span, relsnip=True, num_fragments=num_fragments) + + return query_results @router.post("/explain") async def get_explanation(params: Explanation): diff --git a/neuralqa/server/routemodels.py b/neuralqa/server/routemodels.py index 09cbe58..51cdb47 100644 --- a/neuralqa/server/routemodels.py +++ b/neuralqa/server/routemodels.py @@ -19,7 +19,7 @@ class Answer(BaseModel): question: str = "what is a fourth amendment right violation? " highlight_span: int = 250 tokenstride: int = 50 - context: str = "The fourth amendment kind of protects the rights of citizens .. such that they dont get searched" + context: Optional[str] = "The fourth amendment kind of protects the rights of citizens .. such that they dont get searched" reader: str = None relsnip: bool = True retriever: Optional[str] = "manual"