From a961841f659f0f3eeee689ff865954f782a40635 Mon Sep 17 00:00:00 2001 From: Paul Crossley Date: Thu, 5 Dec 2024 10:27:26 +0000 Subject: [PATCH] n_repeats and classifier confidence added to segmentation component --- .../modelling/components/segment_component.ts | 68 +++++++++++++++---- app/javascript/projects/node_component.tsx | 4 +- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/app/javascript/projects/modelling/components/segment_component.ts b/app/javascript/projects/modelling/components/segment_component.ts index 3c6fd1e..c9417a5 100644 --- a/app/javascript/projects/modelling/components/segment_component.ts +++ b/app/javascript/projects/modelling/components/segment_component.ts @@ -6,10 +6,10 @@ import { ProjectProperties } from "." import { TextControl } from "../controls/text" import { BooleanTileGrid, NumericTileGrid } from "../tile_grid" import { createXYZ } from "ol/tilegrid" -import { Point, Polygon } from "ol/geom" +import { Point } from "ol/geom" import { Coordinate } from "ol/coordinate" -async function retrieveSegmentationMasks(prompts: string, confidence: string, projectProps: ProjectProperties) : Promise{ +async function retrieveSegmentationMasks(prompts: string, det_conf: string, clf_conf: string, n_repeats: string, projectProps: ProjectProperties, err: (err: string) => void) : Promise{ const tileGrid = createXYZ() @@ -18,7 +18,9 @@ async function retrieveSegmentationMasks(prompts: string, confidence: string, pr const segs = await fetch("https://landscapes.wearepal.ai/api/v1/segment?" + new URLSearchParams( { labels: prompts, - confidence, + det_conf, + clf_conf, + n_repeats, bbox: projectProps.extent.join(","), layer: "rgb:full_mosaic_3857", height: outputTileRange.getHeight().toString(), @@ -26,10 +28,20 @@ async function retrieveSegmentationMasks(prompts: string, confidence: string, pr } )) + if(segs.status !== 200){ + err(segs.statusText) + return [] + } + const segsJson = await segs.json() const preds = segsJson.predictions + if(preds === null){ + err("No predictions found") + return [] + } + const result = new BooleanTileGrid( projectProps.zoom, outputTileRange.minX, @@ -103,8 +115,23 @@ export class SegmentComponent extends BaseComponent { async builder(node: Node) { - if (!('confidence' in node.data)) { - node.data.confidence = "10" + node.meta.toolTip = "This node takes in 4 inputs: a prompt, a detector confidence, " + +"a classifier confidence, and the number of repeats. It then returns a segmentation mask, a detection box, " + +"and a confidence value. The prompt is the object you want to segment, detector confidence is the confidence " + +"threshold for the detector (it is recommended that this is set low for high recall), classifier confidence is " + +"the confidence threshold for the classifier (it is recommendeded that this is set higher for increased accuracy." + +" please note: setting this to 0 will disable this function), and the number of repeats is the number of times you want to repeat the segmentation process." + + if (!('det_conf' in node.data)) { + node.data.det_conf = "5" + } + + if (!('cls_conf' in node.data)) { + node.data.cls_conf = "75" + } + + if (!('n_repeats' in node.data)) { + node.data.n_repeats = "5" } if (!('prompt' in node.data)) { @@ -116,7 +143,9 @@ export class SegmentComponent extends BaseComponent { node.addOutput(new Output('box', 'Detection Box', booleanDataSocket)) node.addControl(new TextControl(this.editor, 'prompt', 'Prompt', '500px')) - node.addControl(new TextControl(this.editor, 'confidence', 'Confidence (%)', '100px')) + node.addControl(new TextControl(this.editor, 'det_conf', 'Detector Confidence (%)', '100px')) + node.addControl(new TextControl(this.editor, 'cls_conf', 'Classifier Confidence (%)', '100px')) + node.addControl(new TextControl(this.editor, 'n_repeats', 'Repeats', '100px')) } @@ -126,19 +155,30 @@ export class SegmentComponent extends BaseComponent { if (editorNode === undefined) { return } const prompts = node.data.prompt as string - const confidence = node.data.confidence as string + const det_conf = node.data.det_conf as string + const cls_conf = node.data.cls_conf as string + const n_repeats = node.data.n_repeats as string - if (this.cache.has(`${prompts}_${confidence}%`)) { - const result = this.cache.get(`${prompts}_${confidence}%`)! + if (this.cache.has(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`)) { + const result = this.cache.get(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`)! outputs['mask'] = result[0] outputs['box'] = result[1] outputs['conf'] = result[2] }else{ - const result = await retrieveSegmentationMasks(prompts, confidence, this.projectProps) - this.cache.set(`${prompts}_${confidence}%`, result) - outputs['mask'] = result[0] - outputs['box'] = result[1] - outputs['conf'] = result[2] + let nodeErr = "" + const result = await retrieveSegmentationMasks(prompts, det_conf, cls_conf, n_repeats, this.projectProps, (err) => { + nodeErr = err + }) + if (result.length === 0) { + editorNode.meta.errorMessage = nodeErr + editorNode.update() + }else{ + delete editorNode.meta.errorMessage + this.cache.set(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`, result) + outputs['mask'] = result[0] + outputs['box'] = result[1] + outputs['conf'] = result[2] + } } } diff --git a/app/javascript/projects/node_component.tsx b/app/javascript/projects/node_component.tsx index 547252d..4dcba90 100644 --- a/app/javascript/projects/node_component.tsx +++ b/app/javascript/projects/node_component.tsx @@ -36,11 +36,11 @@ export class NodeComponent extends Node { return (