Skip to content

Commit dd54725

Browse files
committed
add relevance endpoint
1 parent 4439145 commit dd54725

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ wheels/
2424
.installed.cfg
2525
*.egg
2626
MANIFEST
27+
src/
2728

2829
# PyInstaller
2930
# Usually these files are written by a python script from a template

matscholar/rest.py

+18
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,28 @@ def search_text_with_ents(self, text, filters, cutoff=1000):
301301

302302
return self._make_request(sub_url, payload=payload, method=method)
303303

304+
def classify_relevance(self, docs, decision_boundary=0.5):
305+
"""
306+
Determine whether or not a document relates to inorganic material science.
307+
308+
:param docs: list of strings; the documents to be classified
309+
:param decision_boundary: float; decision boundary for the classifier
310+
:return: list; classification labels for each doc (1 or 0)
311+
"""
312+
method = "POST"
313+
sub_url = "/relevance"
314+
payload = {
315+
"docs": docs,
316+
"decision_boundary": decision_boundary
317+
}
318+
319+
return self._make_request(sub_url, payload=payload, method=method)
320+
304321

305322
class MatScholarRestError(Exception):
306323
"""
307324
Exception class for MatstractRester.
308325
Raised when the query has problems, e.g., bad query format.
309326
"""
310327
pass
328+

matscholar/tests/test_rest.py

+8
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,14 @@ def test_normalized(self):
224224
self.assertEqual(tagged_docs[0][0][2][1], "MAT")
225225
self.assertTrue(isinstance(tagged_docs[0][0][2][0], list))
226226

227+
def test_relevance(self):
228+
test_docs = ["The polymer may be used in OLEDs and biosensors.",
229+
"The band gap of ZnO is 3.3 eV"]
230+
preds = self.rester.classify_relevance(test_docs)
231+
self.assertEqual(preds[0], 0)
232+
self.assertEqual(preds[1], 1)
233+
234+
227235
class MaterialSearchEntsTest(unittest.TestCase):
228236

229237
rester = Rester()

0 commit comments

Comments
 (0)