diff --git a/backend/serve.py b/backend/serve.py index 566342a..c82409f 100644 --- a/backend/serve.py +++ b/backend/serve.py @@ -28,14 +28,14 @@ def root(): @app.post("/imageFromPromt") -def image_from_promt(payload): +def image_from_promt(payload: PromptPayload): consumer = StableDiffusionConsumer() image_bytes = consumer.fetch_image(payload.prompt) return StreamingResponse(consumer.parse_to_bytesio(image_bytes), media_type="image/png") @app.post("/base64FromPrompt") -def base64_from_prompt(payload): +def base64_from_prompt(payload: PromptPayload): consumer = StableDiffusionConsumer() img_json = { "image": consumer.fetch_image(payload.prompt) @@ -44,7 +44,7 @@ def base64_from_prompt(payload): return JSONResponse(content=img_json) @app.post("/promptsFromText") -def prompts_from_text(text_payload): +def prompts_from_text(text_payload: TextPayload): prompt_extractor = PromptExtractor() prompts = prompt_extractor.extract_paragraphs_with_prompts(text_payload.text)