|
4 | 4 | from fastapi import APIRouter, Depends, Header, HTTPException, Request |
5 | 5 | from loguru import logger |
6 | 6 |
|
| 7 | +from prompting.datasets.random_website import DDGDatasetEntry |
7 | 8 | from prompting.llms.model_zoo import ModelZoo |
8 | 9 | from prompting.rewards.scoring import task_scorer |
9 | 10 | from prompting.tasks.inference import InferenceTask |
| 11 | +from prompting.tasks.web_retrieval import WebRetrievalTask |
10 | 12 | from shared.base import DatasetEntry |
11 | 13 | from shared.dendrite import DendriteResponseEvent |
12 | 14 | from shared.epistula import SynapseStreamResult |
@@ -37,22 +39,54 @@ async def score_response(request: Request, api_key_data: dict = Depends(validate |
37 | 39 | uid = int(payload.get("uid")) |
38 | 40 | chunks = payload.get("chunks") |
39 | 41 | llm_model = ModelZoo.get_model_by_id(model) if (model := body.get("model")) else None |
40 | | - task_scorer.add_to_queue( |
41 | | - task=InferenceTask( |
42 | | - messages=[msg["content"] for msg in body.get("messages")], |
43 | | - llm_model=llm_model, |
44 | | - llm_model_id=body.get("model"), |
45 | | - seed=int(body.get("seed", 0)), |
46 | | - sampling_params=body.get("sampling_params", {}), |
47 | | - ), |
48 | | - response=DendriteResponseEvent( |
49 | | - uids=[uid], |
50 | | - stream_results=[SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])], |
51 | | - timeout=shared_settings.NEURON_TIMEOUT, |
52 | | - ), |
53 | | - dataset_entry=DatasetEntry(), |
54 | | - block=shared_settings.METAGRAPH.block, |
55 | | - step=-1, |
56 | | - task_id=str(uuid.uuid4()), |
57 | | - ) |
58 | | - logger.info("Organic tas appended to scoring queue") |
| 42 | + task = body.get("task") |
| 43 | + if task == "InferenceTask": |
| 44 | + logger.info(f"Received Organic InferenceTask with body: {body}") |
| 45 | + task_scorer.add_to_queue( |
| 46 | + task=InferenceTask( |
| 47 | + messages=[msg["content"] for msg in body.get("messages")], |
| 48 | + llm_model=llm_model, |
| 49 | + llm_model_id=body.get("model"), |
| 50 | + seed=int(body.get("seed", 0)), |
| 51 | + sampling_params=body.get("sampling_params", {}), |
| 52 | + ), |
| 53 | + response=DendriteResponseEvent( |
| 54 | + uids=[uid], |
| 55 | + stream_results=[ |
| 56 | + SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None]) |
| 57 | + ], |
| 58 | + timeout=shared_settings.NEURON_TIMEOUT, |
| 59 | + ), |
| 60 | + dataset_entry=DatasetEntry(), |
| 61 | + block=shared_settings.METAGRAPH.block, |
| 62 | + step=-1, |
| 63 | + task_id=str(uuid.uuid4()), |
| 64 | + ) |
| 65 | + elif task == "WebRetrievalTask": |
| 66 | + logger.info(f"Received Organic WebRetrievalTask with body: {body}") |
| 67 | + try: |
| 68 | + search_term = body.get("messages")[0].get("content") |
| 69 | + except Exception as ex: |
| 70 | + logger.error(f"Failed to get search term from messages: {ex}, can't score WebRetrievalTask") |
| 71 | + return |
| 72 | + |
| 73 | + task_scorer.add_to_queue( |
| 74 | + task=WebRetrievalTask( |
| 75 | + messages=[msg["content"] for msg in body.get("messages")], |
| 76 | + seed=int(body.get("seed", 0)), |
| 77 | + sampling_params=body.get("sampling_params", {}), |
| 78 | + query=search_term, |
| 79 | + ), |
| 80 | + response=DendriteResponseEvent( |
| 81 | + uids=[uid], |
| 82 | + stream_results=[ |
| 83 | + SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None]) |
| 84 | + ], |
| 85 | + timeout=shared_settings.NEURON_TIMEOUT, |
| 86 | + ), |
| 87 | + dataset_entry=DDGDatasetEntry(search_term=search_term), |
| 88 | + block=shared_settings.METAGRAPH.block, |
| 89 | + step=-1, |
| 90 | + task_id=str(uuid.uuid4()), |
| 91 | + ) |
| 92 | + logger.info("Organic task appended to scoring queue") |
0 commit comments