|
6 | 6 | package software.amazon.smithy.java.mcp.server;
|
7 | 7 |
|
8 | 8 | import java.util.ArrayList;
|
| 9 | +import java.util.HashMap; |
9 | 10 | import java.util.LinkedHashMap;
|
10 | 11 | import java.util.List;
|
11 | 12 | import java.util.Map;
|
12 |
| -import java.util.Set; |
13 | 13 | import software.amazon.smithy.ai.PromptTemplateDefinition;
|
14 | 14 | import software.amazon.smithy.ai.PromptsTrait;
|
| 15 | +import software.amazon.smithy.java.core.schema.Schema; |
| 16 | +import software.amazon.smithy.java.core.schema.TraitKey; |
15 | 17 | import software.amazon.smithy.java.mcp.model.PromptArgument;
|
16 | 18 | import software.amazon.smithy.java.mcp.model.PromptInfo;
|
17 |
| -import software.amazon.smithy.model.Model; |
18 |
| -import software.amazon.smithy.model.shapes.Shape; |
19 |
| -import software.amazon.smithy.model.shapes.ShapeId; |
20 |
| -import software.amazon.smithy.model.shapes.StructureShape; |
21 |
| -import software.amazon.smithy.model.traits.DocumentationTrait; |
22 |
| -import software.amazon.smithy.model.traits.RequiredTrait; |
| 19 | +import software.amazon.smithy.java.server.Service; |
23 | 20 | import software.amazon.smithy.utils.SmithyUnstableApi;
|
24 | 21 |
|
25 | 22 | /**
|
26 | 23 | * Handles loading and parsing of prompts from Smithy models.
|
27 | 24 | */
|
28 | 25 | @SmithyUnstableApi
|
29 |
| -public final class PromptLoader { |
| 26 | +final class PromptLoader { |
| 27 | + |
| 28 | + private static final TraitKey<PromptsTrait> PROMPTS_TRAIT_KEY = TraitKey.get(PromptsTrait.class); |
30 | 29 |
|
31 | 30 | public static final String TOOL_PREFERENCE_PREFIX = ".Tool preference: ";
|
32 | 31 |
|
33 | 32 | /**
|
34 | 33 | * Loads prompts from the provided Smithy models.
|
35 | 34 | *
|
36 |
| - * @param models List of Smithy models to extract prompts from |
37 | 35 | * @return Map of prompt names to PromptInfo objects
|
38 | 36 | */
|
39 |
| - public static Map<String, PromptInfo> loadPrompts(List<Model> models) { |
| 37 | + public static Map<String, PromptInfo> loadPrompts(List<Service> services) { |
40 | 38 | Map<String, PromptInfo> promptInfos = new LinkedHashMap<>();
|
41 | 39 |
|
42 |
| - for (Model model : models) { |
43 |
| - Set<Shape> promptShapes = model.getShapesWithTrait(PromptsTrait.ID); |
44 |
| - for (Shape prompt : promptShapes) { |
45 |
| - |
46 |
| - Map<String, PromptTemplateDefinition> promptDefinitions = |
47 |
| - prompt.expectTrait(PromptsTrait.class).getValues(); |
48 |
| - for (Map.Entry<String, PromptTemplateDefinition> entry : promptDefinitions.entrySet()) { |
49 |
| - var promptName = entry.getKey().toLowerCase(); |
50 |
| - var promptTemplateDefinition = entry.getValue(); |
51 |
| - var templateString = promptTemplateDefinition.getTemplate(); |
52 |
| - |
53 |
| - promptInfos.put( |
54 |
| - promptName, |
55 |
| - PromptInfo |
56 |
| - .builder() |
57 |
| - .name(promptName) |
58 |
| - .description(promptTemplateDefinition.getDescription()) |
59 |
| - .template( |
60 |
| - promptTemplateDefinition.getPreferWhen().isPresent() |
61 |
| - ? templateString + TOOL_PREFERENCE_PREFIX |
62 |
| - + promptTemplateDefinition.getPreferWhen().get() |
63 |
| - : templateString) |
64 |
| - .arguments(promptTemplateDefinition.getArguments().isPresent() |
65 |
| - ? convertArgumentShapeToPromptArgument(model, |
66 |
| - promptTemplateDefinition.getArguments().get()) |
67 |
| - : List.of()) |
68 |
| - .build()); |
| 40 | + for (var service : services) { |
| 41 | + Map<String, PromptTemplateDefinition> promptDefinitions = new HashMap<>(); |
| 42 | + var servicePromptTrait = service.schema().getTrait(PROMPTS_TRAIT_KEY); |
| 43 | + if (servicePromptTrait != null) { |
| 44 | + promptDefinitions.putAll(servicePromptTrait.getValues()); |
| 45 | + } |
| 46 | + service.getAllOperations().forEach(operation -> { |
| 47 | + var operationPromptsTrait = operation.getApiOperation().schema().getTrait(PROMPTS_TRAIT_KEY); |
| 48 | + if (operationPromptsTrait != null) { |
| 49 | + promptDefinitions.putAll(operationPromptsTrait.getValues()); |
69 | 50 | }
|
| 51 | + |
| 52 | + }); |
| 53 | + for (Map.Entry<String, PromptTemplateDefinition> entry : promptDefinitions.entrySet()) { |
| 54 | + var promptName = entry.getKey().toLowerCase(); |
| 55 | + var promptTemplateDefinition = entry.getValue(); |
| 56 | + var templateString = promptTemplateDefinition.getTemplate(); |
| 57 | + |
| 58 | + promptInfos.put( |
| 59 | + promptName, |
| 60 | + PromptInfo |
| 61 | + .builder() |
| 62 | + .name(promptName) |
| 63 | + .description(promptTemplateDefinition.getDescription()) |
| 64 | + .template( |
| 65 | + promptTemplateDefinition.getPreferWhen().isPresent() |
| 66 | + ? templateString + TOOL_PREFERENCE_PREFIX |
| 67 | + + promptTemplateDefinition.getPreferWhen().get() |
| 68 | + : templateString) |
| 69 | + .arguments(promptTemplateDefinition.getArguments().isPresent() |
| 70 | + ? convertArgumentShapeToPromptArgument( |
| 71 | + service.schemaIndex() |
| 72 | + .getSchema(promptTemplateDefinition.getArguments().get())) |
| 73 | + : List.of()) |
| 74 | + .build()); |
70 | 75 | }
|
71 | 76 | }
|
72 |
| - |
73 | 77 | return promptInfos;
|
74 | 78 | }
|
75 | 79 |
|
76 | 80 | /**
|
77 | 81 | * Converts a Smithy structure shape to a list of PromptArgument objects.
|
78 | 82 | *
|
79 |
| - * @param model The Smithy model containing the shape |
80 | 83 | * @param argumentShapeId The ShapeId of the structure to convert
|
81 | 84 | * @return List of PromptArgument objects representing the structure members
|
82 | 85 | */
|
83 |
| - public static List<PromptArgument> convertArgumentShapeToPromptArgument(Model model, ShapeId argumentShapeId) { |
84 |
| - StructureShape argument = model.expectShape(argumentShapeId, StructureShape.class); |
| 86 | + public static List<PromptArgument> convertArgumentShapeToPromptArgument(Schema argument) { |
85 | 87 | List<PromptArgument> promptArguments = new ArrayList<>();
|
86 | 88 |
|
87 |
| - for (var member : argument.getAllMembers().entrySet()) { |
88 |
| - String memberName = member.getKey(); |
89 |
| - var memberShape = member.getValue(); |
| 89 | + for (var member : argument.members()) { |
| 90 | + String memberName = member.memberName(); |
90 | 91 |
|
91 | 92 | // Get description from documentation trait, use empty string if not present
|
92 | 93 | String description = "";
|
93 |
| - var documentationTrait = memberShape.getTrait(DocumentationTrait.class); |
94 |
| - if (documentationTrait.isPresent()) { |
95 |
| - description = documentationTrait.get().getValue(); |
| 94 | + var documentationTrait = member.getTrait(TraitKey.DOCUMENTATION_TRAIT); |
| 95 | + if (documentationTrait != null) { |
| 96 | + description = documentationTrait.getValue(); |
96 | 97 | }
|
97 | 98 |
|
98 | 99 | // Check if member is required
|
99 |
| - boolean isRequired = memberShape.hasTrait(RequiredTrait.class); |
| 100 | + boolean isRequired = member.getTrait(TraitKey.REQUIRED_TRAIT) != null; |
100 | 101 |
|
101 | 102 | // Build the PromptArgument
|
102 | 103 | PromptArgument promptArgument = PromptArgument.builder()
|
|
0 commit comments