Skip to content

Commit df0382c

Browse files
committed
internal/llm: add interface PolicyChecker
A PolicyChecker is used to check the inputs and outputs of an LLM against safety policies. We will implement this interface with the GCP Checks Guardrails API. For #70 Change-Id: I52a776cc94900cef4c0b56f284c56e44f5136d4b Reviewed-on: https://go-review.googlesource.com/c/oscar/+/637975 Reviewed-by: Hyang-Ah Hana Kim <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent ceb8002 commit df0382c

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

internal/llm/policy_checker.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package llm
6+
7+
import (
8+
"context"
9+
"fmt"
10+
)
11+
12+
// A PolicyChecker checks inputs and outputs to LLMs against
13+
// safety policies.
14+
type PolicyChecker interface {
15+
// SetPolicies sets the policies to evaluate in subsequent
16+
// calls to [Check]. If unset, use the implementation's default.
17+
SetPolicies([]*PolicyConfig)
18+
// CheckText evaluates the policies configured on this [PolicyChecker]
19+
// against the given text and returns a result for each [PolicyConfig].
20+
// If the text represents a model output, the prompt parts used to generate it
21+
// may optionally be provided as context. If the text represents a model input,
22+
// prompt should be empty.
23+
CheckText(ctx context.Context, text string, prompt ...Part) ([]*PolicyResult, error)
24+
}
25+
26+
// A PolicyConfig is a policy to apply to an input or output to an LLM.
27+
//
28+
// Copied from "google.golang.org/api/checks/v1alpha" to avoid direct dependency.
29+
type PolicyConfig struct {
30+
// PolicyType: Required. Type of the policy.
31+
PolicyType PolicyType
32+
// Threshold: Optional. Score threshold to use when deciding if the content is
33+
// violative or non-violative. If not specified, the default 0.5 threshold for
34+
// the policy will be used.
35+
Threshold float64
36+
}
37+
38+
// A PolicyResult is the result of evaluating a policy against
39+
// an input or output to an LLM.
40+
//
41+
// Copied from "google.golang.org/api/checks/v1alpha" to avoid direct dependency.
42+
type PolicyResult struct {
43+
// PolicyType: Type of the policy.
44+
PolicyType PolicyType
45+
// Score: Final score for the results of this policy.
46+
Score float64
47+
// ViolationResult: Result of the classification for the policy.
48+
ViolationResult ViolationResult
49+
}
50+
51+
// A PolicyType names the safety policy that a [PolicyConfig] applies
// or that a [PolicyResult] reports on.
type PolicyType string

// Possible values for [PolicyType].
const (
	// PolicyTypeUnspecified is the default value.
	PolicyTypeUnspecified PolicyType = "POLICY_TYPE_UNSPECIFIED"

	// PolicyTypeDangerousContent: the model facilitates, promotes or
	// enables access to harmful goods, services, and activities.
	PolicyTypeDangerousContent PolicyType = "DANGEROUS_CONTENT"

	// PolicyTypePIISolicitingReciting: the model reveals an
	// individual's personal information and data.
	PolicyTypePIISolicitingReciting PolicyType = "PII_SOLICITING_RECITING"

	// PolicyTypeHarassment: the model generates content that is
	// malicious, intimidating, bullying, or abusive towards another
	// individual.
	PolicyTypeHarassment PolicyType = "HARASSMENT"

	// PolicyTypeSexuallyExplicit: the model generates content that is
	// sexually explicit in nature.
	PolicyTypeSexuallyExplicit PolicyType = "SEXUALLY_EXPLICIT"

	// PolicyTypeHateSpeech: the model promotes violence, hatred,
	// discrimination on the basis of race, religion, etc.
	PolicyTypeHateSpeech PolicyType = "HATE_SPEECH"

	// PolicyTypeMedicalInfo: the model provides or offers to
	// facilitate access to medical advice or guidance.
	PolicyTypeMedicalInfo PolicyType = "MEDICAL_INFO"

	// PolicyTypeViolenceAndGore: the model generates content that
	// contains gratuitous, realistic descriptions of violence or gore.
	PolicyTypeViolenceAndGore PolicyType = "VIOLENCE_AND_GORE"

	// PolicyTypeObscenityAndProfanity: the model generates profanity
	// and obscenities.
	PolicyTypeObscenityAndProfanity PolicyType = "OBSCENITY_AND_PROFANITY"
)
81+
82+
// AllPolicyTypes returns a policy that, when passed to
83+
// to [PolicyChecker.SetPolicies], configures the PolicyChecker
84+
// to check for all available dangerous content types at the default threshold.
85+
func AllPolicyTypes() []*PolicyConfig {
86+
return []*PolicyConfig{
87+
{PolicyType: PolicyTypeDangerousContent},
88+
{PolicyType: PolicyTypePIISolicitingReciting},
89+
{PolicyType: PolicyTypeHarassment},
90+
{PolicyType: PolicyTypeSexuallyExplicit},
91+
{PolicyType: PolicyTypeHateSpeech},
92+
{PolicyType: PolicyTypeMedicalInfo},
93+
{PolicyType: PolicyTypeViolenceAndGore},
94+
{PolicyType: PolicyTypeObscenityAndProfanity},
95+
}
96+
}
97+
98+
// A ViolationResult is the outcome of classifying content against a
// policy, as reported in [PolicyResult].
type ViolationResult string

// Possible values for [ViolationResult].
const (
	// ViolationResultUnspecified is an unspecified result.
	ViolationResultUnspecified ViolationResult = "VIOLATION_RESULT_UNSPECIFIED"

	// ViolationResultViolative: the final score is greater than or
	// equal to the input score threshold.
	ViolationResultViolative ViolationResult = "VIOLATIVE"

	// ViolationResultNonViolative: the final score is smaller than
	// the input score threshold.
	ViolationResultNonViolative ViolationResult = "NON_VIOLATIVE"

	// ViolationResultClassificationError: there was an error and the
	// violation result could not be determined.
	ViolationResultClassificationError ViolationResult = "CLASSIFICATION_ERROR"
)
114+
115+
// IsViolative reports whether the policy result represents
116+
// a violated policy.
117+
func (pr *PolicyResult) IsViolative() bool {
118+
return pr.ViolationResult == ViolationResultViolative
119+
}
120+
121+
// String returns a string representation of the policy result.
122+
func (pr *PolicyResult) String() string {
123+
return fmt.Sprintf("%s: %s (%f)", pr.PolicyType, pr.ViolationResult, pr.Score)
124+
}

0 commit comments

Comments
 (0)