Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

n_repeats and classifier confidence added to segmentation component #462

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 54 additions & 14 deletions app/javascript/projects/modelling/components/segment_component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import { ProjectProperties } from "."
import { TextControl } from "../controls/text"
import { BooleanTileGrid, NumericTileGrid } from "../tile_grid"
import { createXYZ } from "ol/tilegrid"
import { Point, Polygon } from "ol/geom"
import { Point } from "ol/geom"
import { Coordinate } from "ol/coordinate"

async function retrieveSegmentationMasks(prompts: string, confidence: string, projectProps: ProjectProperties) : Promise<any[]>{
async function retrieveSegmentationMasks(prompts: string, det_conf: string, clf_conf: string, n_repeats: string, projectProps: ProjectProperties, err: (err: string) => void) : Promise<any[]>{

const tileGrid = createXYZ()

Expand All @@ -18,18 +18,30 @@ async function retrieveSegmentationMasks(prompts: string, confidence: string, pr
const segs = await fetch("https://landscapes.wearepal.ai/api/v1/segment?" + new URLSearchParams(
{
labels: prompts,
confidence,
det_conf,
clf_conf,
n_repeats,
bbox: projectProps.extent.join(","),
layer: "rgb:full_mosaic_3857",
height: outputTileRange.getHeight().toString(),
width: outputTileRange.getWidth().toString(),
}
))

if(segs.status !== 200){
err(segs.statusText)
return []
}

const segsJson = await segs.json()

const preds = segsJson.predictions

if(preds === null){
err("No predictions found")
return []
}

const result = new BooleanTileGrid(
projectProps.zoom,
outputTileRange.minX,
Expand Down Expand Up @@ -103,8 +115,23 @@ export class SegmentComponent extends BaseComponent {

async builder(node: Node) {

if (!('confidence' in node.data)) {
node.data.confidence = "10"
node.meta.toolTip = "This node takes in 4 inputs: a prompt, a detector confidence, "
+"a classifier confidence, and the number of repeats. It then returns a segmentation mask, a detection box, "
+"and a confidence value. The prompt is the object you want to segment, detector confidence is the confidence "
+"threshold for the detector (it is recommended that this is set low for high recall), classifier confidence is "
+"the confidence threshold for the classifier (it is recommendeded that this is set higher for increased accuracy."
+" please note: setting this to 0 will disable this function), and the number of repeats is the number of times you want to repeat the segmentation process."

if (!('det_conf' in node.data)) {
node.data.det_conf = "5"
}

if (!('cls_conf' in node.data)) {
node.data.cls_conf = "75"
}

if (!('n_repeats' in node.data)) {
node.data.n_repeats = "5"
}

if (!('prompt' in node.data)) {
Expand All @@ -116,7 +143,9 @@ export class SegmentComponent extends BaseComponent {
node.addOutput(new Output('box', 'Detection Box', booleanDataSocket))

node.addControl(new TextControl(this.editor, 'prompt', 'Prompt', '500px'))
node.addControl(new TextControl(this.editor, 'confidence', 'Confidence (%)', '100px'))
node.addControl(new TextControl(this.editor, 'det_conf', 'Detector Confidence (%)', '100px'))
node.addControl(new TextControl(this.editor, 'cls_conf', 'Classifier Confidence (%)', '100px'))
node.addControl(new TextControl(this.editor, 'n_repeats', 'Repeats', '100px'))

}

Expand All @@ -126,19 +155,30 @@ export class SegmentComponent extends BaseComponent {
if (editorNode === undefined) { return }

const prompts = node.data.prompt as string
const confidence = node.data.confidence as string
const det_conf = node.data.det_conf as string
const cls_conf = node.data.cls_conf as string
const n_repeats = node.data.n_repeats as string

if (this.cache.has(`${prompts}_${confidence}%`)) {
const result = this.cache.get(`${prompts}_${confidence}%`)!
if (this.cache.has(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`)) {
const result = this.cache.get(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`)!
outputs['mask'] = result[0]
outputs['box'] = result[1]
outputs['conf'] = result[2]
}else{
const result = await retrieveSegmentationMasks(prompts, confidence, this.projectProps)
this.cache.set(`${prompts}_${confidence}%`, result)
outputs['mask'] = result[0]
outputs['box'] = result[1]
outputs['conf'] = result[2]
let nodeErr = ""
const result = await retrieveSegmentationMasks(prompts, det_conf, cls_conf, n_repeats, this.projectProps, (err) => {
nodeErr = err
})
if (result.length === 0) {
editorNode.meta.errorMessage = nodeErr
editorNode.update()
}else{
delete editorNode.meta.errorMessage
this.cache.set(`${prompts}_${cls_conf}%${det_conf}%${n_repeats}`, result)
outputs['mask'] = result[0]
outputs['box'] = result[1]
outputs['conf'] = result[2]
}
}

}
Expand Down
4 changes: 2 additions & 2 deletions app/javascript/projects/node_component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ export class NodeComponent extends Node {

return (
<div className={`node ${selected}`} style={{
boxShadow: "0px 0px 8px rgba(0, 0, 0, 0.25)",
boxShadow: "0px 0px 8px rgba(0, 0, 0, 0.25)",
background: "rgba(0, 0, 0, 0.5)",
color: "white",
borderRadius: "4px",
border: "solid 3px transparent",
border: node.meta.errorMessage ? "solid 2px rgba(210, 0, 0, .71)" : "solid 3px transparent",
cursor: "pointer",
minWidth: "250px",
height: "auto",
Expand Down
Loading