diff --git a/ast.go b/ast.go index 7f03876..49d135f 100644 --- a/ast.go +++ b/ast.go @@ -2,7 +2,7 @@ package pgs import ( "github.com/golang/protobuf/protoc-gen-go/descriptor" - plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/golang/protobuf/protoc-gen-go/plugin" ) // AST encapsulates the entirety of the input CodeGeneratorRequest from protoc, @@ -26,9 +26,10 @@ type AST interface { type graph struct { d Debugger - targets map[string]File - packages map[string]Package - entities map[string]Entity + targets map[string]File + packages map[string]Package + entities map[string]Entity + extensions []Extension } func (g *graph) Targets() map[string]File { return g.targets } @@ -49,10 +50,11 @@ func ProcessDescriptors(debug Debugger, req *plugin_go.CodeGeneratorRequest) AST // connected AST entity graph. An error is returned if the input is malformed. func ProcessCodeGeneratorRequest(debug Debugger, req *plugin_go.CodeGeneratorRequest) AST { g := &graph{ - d: debug, - targets: make(map[string]File, len(req.GetFileToGenerate())), - packages: make(map[string]Package), - entities: make(map[string]Entity), + d: debug, + targets: make(map[string]File, len(req.GetFileToGenerate())), + packages: make(map[string]Package), + entities: make(map[string]Entity), + extensions: []Extension{}, } for _, f := range req.GetFileToGenerate() { @@ -64,6 +66,15 @@ func ProcessCodeGeneratorRequest(debug Debugger, req *plugin_go.CodeGeneratorReq pkg.addFile(g.hydrateFile(pkg, f)) } + for _, e := range g.extensions { + e.addType(g.hydrateFieldType(e)) + extendee := g.mustSeen(e.Descriptor().GetExtendee()).(Message) + e.setExtendee(extendee) + if extendee != nil { + extendee.addExtension(e) + } + } + return g } @@ -111,6 +122,13 @@ func (g *graph) hydrateFile(pkg Package, f *descriptor.FileDescriptorProto) File fl.addEnum(g.hydrateEnum(fl, e)) } + exts := f.GetExtension() + fl.defExts = make([]Extension, 0, len(exts)) + for _, ext := range exts { + e := g.hydrateExtension(fl, ext) + fl.addDefExtension(e) + } + msgs := f.GetMessageType() fl.msgs = make([]Message, 0, len(f.GetMessageType())) for _, msg := range msgs { @@ -250,6 +268,13 @@ func (g *graph) hydrateMessage(p ParentEntity, md *descriptor.DescriptorProto) M } } + exts := md.GetExtension() + m.defExts = make([]Extension, 0, len(exts)) + for _, ext := range md.GetExtension() { + e := g.hydrateExtension(m, ext) + m.addDefExtension(e) + } + return m } @@ -273,6 +298,17 @@ func (g *graph) hydrateOneOf(m Message, od *descriptor.OneofDescriptorProto) One return o } +func (g *graph) hydrateExtension(parent ParentEntity, fd *descriptor.FieldDescriptorProto) Extension { + ext := &ext{ + parent: parent, + } + ext.desc = fd + g.add(ext) + g.extensions = append(g.extensions, ext) + + return ext +} + func (g *graph) hydrateFieldType(fld Field) FieldType { s := &scalarT{fld: fld} diff --git a/ast_test.go b/ast_test.go index 5e72a5e..85ad367 100644 --- a/ast_test.go +++ b/ast_test.go @@ -309,3 +309,25 @@ func TestGraph_Packageless(t *testing.T) { }) } } + +func TestGraph_Extensions(t *testing.T) { + t.Parallel() + + g := buildGraph(t, "extensions") + assert.NotNil(t, g) + + ent, ok := g.Lookup("extensions/ext/data.proto") + assert.True(t, ok) + assert.NotNil(t, ent.(File).DefinedExtensions()) + assert.Len(t, ent.(File).DefinedExtensions(), 6) + + ent, ok = g.Lookup(".extensions.Request") + assert.True(t, ok) + assert.NotNil(t, ent.(Message).DefinedExtensions()) + assert.Len(t, ent.(Message).DefinedExtensions(), 1) + + ent, ok = g.Lookup(".google.protobuf.MessageOptions") + assert.True(t, ok) + assert.NotNil(t, ent.(Message).Extensions()) + assert.Len(t, ent.(Message).Extensions(), 1) +} diff --git a/entity.go b/entity.go index 4b2531c..3567a93 100644 --- a/entity.go +++ b/entity.go @@ -73,7 +73,11 @@ type ParentEntity interface { // AllEnums returns all top-level and nested enums from this entity. AllEnums() []Enum + // DefinedExtensions returns all Extensions defined on this entity. + DefinedExtensions() []Extension + addMessage(m Message) addMapEntry(m Message) addEnum(e Enum) + addDefExtension(e Extension) } diff --git a/extension.go b/extension.go index 01f8d17..99293eb 100644 --- a/extension.go +++ b/extension.go @@ -8,6 +8,50 @@ import ( "github.com/golang/protobuf/proto" ) +// An Extension is a custom option annotation that can be applied to an Entity to provide additional +// semantic details and metadata about the Entity. +type Extension interface { + Field + + // ParentEntity returns the ParentEntity where the Extension is defined + DefinedIn() ParentEntity + + // Extendee returns the Message that the Extension is extending + Extendee() Message + + setExtendee(m Message) +} + +type ext struct { + field + + parent ParentEntity + extendee Message +} + +func (e *ext) FullyQualifiedName() string { return fullyQualifiedName(e.parent, e) } +func (e *ext) Syntax() Syntax { return e.parent.Syntax() } +func (e *ext) Package() Package { return e.parent.Package() } +func (e *ext) File() File { return e.parent.File() } +func (e *ext) BuildTarget() bool { return e.parent.BuildTarget() } +func (e *ext) DefinedIn() ParentEntity { return e.parent } +func (e *ext) Extendee() Message { return e.extendee } +func (e *ext) Message() Message { return nil } +func (e *ext) InOneOf() bool { return false } +func (e *ext) OneOf() OneOf { return nil } +func (e *ext) setMessage(m Message) {} // noop +func (e *ext) setOneOf(o OneOf) {} // noop +func (e *ext) setExtendee(m Message) { e.extendee = m } + +func (e *ext) accept(v Visitor) (err error) { + if v == nil { + return + } + + _, err = v.VisitExtension(e) + return +} + var extractor extExtractor func init() { extractor = protoExtExtractor{} } diff --git a/extension_test.go b/extension_test.go index a84a62b..90bf354 100644 --- a/extension_test.go +++ b/extension_test.go @@ -5,10 +5,103 @@ import ( "errors" "testing" + "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" ) +func TestExt_FullyQualifiedName(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{parent: msg} + e.desc = &descriptor.FieldDescriptorProto{Name: proto.String("foo")} + assert.Equal(t, msg.FullyQualifiedName()+".foo", e.FullyQualifiedName()) +} + +func TestExt_Syntax(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{parent: msg} + assert.Equal(t, msg.Syntax(), e.Syntax()) +} + +func TestExt_Package(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{parent: msg} + assert.Equal(t, msg.Package(), e.Package()) +} + +func TestExt_File(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{parent: msg} + assert.Equal(t, msg.File(), e.File()) +} + +func TestExt_BuildTarget(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{parent: msg} + assert.Equal(t, msg.BuildTarget(), e.BuildTarget()) +} + +func TestExt_ParentEntity(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{parent: msg} + assert.Equal(t, msg, e.DefinedIn()) +} + +func TestExt_Extendee(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + e := &ext{} + e.setExtendee(msg) + assert.Equal(t, msg, e.Extendee()) +} + +func TestExt_Message(t *testing.T) { + t.Parallel() + + e := &ext{} + assert.Nil(t, e.Message()) +} + +func TestExt_InOneOf(t *testing.T) { + t.Parallel() + + e := &ext{} + assert.False(t, e.InOneOf()) +} + +func TestExt_OneOf(t *testing.T) { + t.Parallel() + + e := &ext{} + assert.Nil(t, e.OneOf()) +} + +func TestExt_Accept(t *testing.T) { + t.Parallel() + + e := &ext{} + + assert.NoError(t, e.accept(nil)) + + v := &mockVisitor{err: errors.New("")} + assert.Error(t, e.accept(v)) + assert.Equal(t, 1, v.extension) +} + type mockExtractor struct { has bool get interface{} diff --git a/file.go b/file.go index 9e38d19..695a651 100644 --- a/file.go +++ b/file.go @@ -38,6 +38,7 @@ type file struct { desc *descriptor.FileDescriptorProto pkg Package enums []Enum + defExts []Extension msgs []Message srvs []Service buildTarget bool @@ -114,6 +115,10 @@ func (f *file) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, erro return extension(f.desc.GetOptions(), desc, &ext) } +func (f *file) DefinedExtensions() []Extension { + return f.defExts +} + func (f *file) accept(v Visitor) (err error) { if v == nil { return nil @@ -144,6 +149,10 @@ func (f *file) accept(v Visitor) (err error) { return } +func (f *file) addDefExtension(ext Extension) { + f.defExts = append(f.defExts, ext) +} + func (f *file) setPackage(pkg Package) { f.pkg = pkg } func (f *file) addEnum(e Enum) { diff --git a/file_test.go b/file_test.go index 66060c0..e8ee60f 100644 --- a/file_test.go +++ b/file_test.go @@ -240,6 +240,17 @@ func TestFile_Extension(t *testing.T) { }) } +func TestFile_DefinedExtensions(t *testing.T) { + t.Parallel() + + f := &file{} + assert.Empty(t, f.DefinedExtensions()) + + ext := &ext{} + f.addDefExtension(ext) + assert.Len(t, f.DefinedExtensions(), 1) +} + // needed to wrap since there is a File method type mFile interface { File diff --git a/glide.lock b/glide.lock index 0579926..aaabce9 100644 --- a/glide.lock +++ b/glide.lock @@ -1,41 +1,40 @@ hash: a98be71ff763a2b4b183feef097ec7a862c3d26d01cb7c7484975b8e10dabeb5 -updated: 2018-05-09T12:54:16.045343-07:00 +updated: 2019-03-19T18:13:23.892274-07:00 imports: - name: github.com/golang/protobuf - version: b4deda0973fb4c70b50d226b1af49f3da59f5265 + version: b5d812f8a3706043e23a9cd5babf2e5423744d30 subpackages: + - descriptor - proto - protoc-gen-go/descriptor - protoc-gen-go/generator - protoc-gen-go/generator/internal/remap - protoc-gen-go/plugin + - ptypes/any - name: github.com/spf13/afero - version: 63644898a8da0bc22138abf860edaf5277b6102e + version: f4711e4db9e9a1d3887343acb72b2bbfc2f686f5 subpackages: - mem -- name: golang.org/x/net - version: 2fb46b16b8dda405028c50f7c7f0f9dd1fa6bfb1 - subpackages: - - context - name: golang.org/x/sync - version: 1d60e4601c6fd243af51cc01ddf169918a5407ca + version: 37e7f081c4d4c64e13b10787722085407fe5d15f subpackages: - errgroup - name: golang.org/x/text - version: e19ae1496984b1c655b8044a65c0300a3c878dd3 + version: e6919f6577db79269a6443b9dc46d18f2238fb5d subpackages: - transform - unicode/norm testImports: - name: github.com/davecgh/go-spew - version: 87df7c60d5820d0f8ae11afede5aa52325c09717 + version: d8f796af33cc11cb798c1aaeb27a4ebc5099927d subpackages: - spew - name: github.com/pmezard/go-difflib - version: 792786c7400a136282c1664665ae0a8db921c6c2 + version: 5d4384ee4fb2527b0a1256a821ebfc92f91efefc subpackages: - difflib - name: github.com/stretchr/testify - version: a726187e3128d0a0ec37f73ca7c4d3e508e6c2e5 + version: 363ebb24d041ccea8068222281c2e963e997b9dc subpackages: - assert + - require diff --git a/message.go b/message.go index 2e09ed6..f313756 100644 --- a/message.go +++ b/message.go @@ -32,6 +32,9 @@ type Message interface { // OneOfs returns the OneOfs contained within this Message. OneOfs() []OneOf + // Extensions returns all of the Extensions applied to this Message. + Extensions() []Extension + // IsMapEntry identifies this message as a MapEntry. If true, this message is // not generated as code, and is used exclusively when marshaling a map field // to the wire format. @@ -48,6 +51,7 @@ type Message interface { setParent(p ParentEntity) addField(f Field) + addExtension(e Extension) addOneOf(o OneOf) } @@ -57,6 +61,8 @@ type msg struct { msgs, preservedMsgs []Message enums []Enum + exts []Extension + defExts []Extension fields []Field oneofs []OneOf maps []Message @@ -142,6 +148,14 @@ func (m *msg) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, error return extension(m.desc.GetOptions(), desc, &ext) } +func (m *msg) Extensions() []Extension { + return m.exts +} + +func (m *msg) DefinedExtensions() []Extension { + return m.defExts +} + func (m *msg) accept(v Visitor) (err error) { if v == nil { return nil @@ -178,6 +192,14 @@ func (m *msg) accept(v Visitor) (err error) { return } +func (m *msg) addExtension(ext Extension) { + m.exts = append(m.exts, ext) +} + +func (m *msg) addDefExtension(ext Extension) { + m.defExts = append(m.defExts, ext) +} + func (m *msg) setParent(p ParentEntity) { m.parent = p } func (m *msg) addEnum(e Enum) { diff --git a/message_test.go b/message_test.go index b908c0a..7eaec80 100644 --- a/message_test.go +++ b/message_test.go @@ -220,6 +220,28 @@ func TestMsg_Extension(t *testing.T) { assert.NotPanics(t, func() { m.Extension(nil, nil) }) } +func TestMsg_Extensions(t *testing.T) { + t.Parallel() + + m := &msg{} + assert.Empty(t, m.Extensions()) + + ext := &ext{} + m.addExtension(ext) + assert.Len(t, m.Extensions(), 1) +} + +func TestMsg_DefinedExtensions(t *testing.T) { + t.Parallel() + + m := &msg{} + assert.Empty(t, m.DefinedExtensions()) + + ext := &ext{} + m.addDefExtension(ext) + assert.Len(t, m.DefinedExtensions(), 1) +} + func TestMsg_Accept(t *testing.T) { t.Parallel() diff --git a/node.go b/node.go index a8eccc8..cff6f58 100644 --- a/node.go +++ b/node.go @@ -17,6 +17,7 @@ type Visitor interface { VisitEnum(Enum) (v Visitor, err error) VisitEnumValue(EnumValue) (v Visitor, err error) VisitField(Field) (v Visitor, err error) + VisitExtension(Extension) (v Visitor, err error) VisitOneOf(OneOf) (v Visitor, err error) VisitService(Service) (v Visitor, err error) VisitMethod(Method) (v Visitor, err error) @@ -40,6 +41,7 @@ func (nv nilVisitor) VisitMessage(m Message) (v Visitor, err error) { return func (nv nilVisitor) VisitEnum(e Enum) (v Visitor, err error) { return nil, nil } func (nv nilVisitor) VisitEnumValue(e EnumValue) (v Visitor, err error) { return nil, nil } func (nv nilVisitor) VisitField(f Field) (v Visitor, err error) { return nil, nil } +func (nv nilVisitor) VisitExtension(e Extension) (v Visitor, err error) { return nil, nil } func (nv nilVisitor) VisitOneOf(o OneOf) (v Visitor, err error) { return nil, nil } func (nv nilVisitor) VisitService(s Service) (v Visitor, err error) { return nil, nil } func (nv nilVisitor) VisitMethod(m Method) (v Visitor, err error) { return nil, nil } @@ -63,6 +65,7 @@ func (pv passVisitor) VisitMessage(Message) (v Visitor, err error) { return func (pv passVisitor) VisitEnum(Enum) (v Visitor, err error) { return pv.v, nil } func (pv passVisitor) VisitEnumValue(EnumValue) (v Visitor, err error) { return pv.v, nil } func (pv passVisitor) VisitField(Field) (v Visitor, err error) { return pv.v, nil } +func (pv passVisitor) VisitExtension(Extension) (v Visitor, err error) { return pv.v, nil } func (pv passVisitor) VisitOneOf(OneOf) (v Visitor, err error) { return pv.v, nil } func (pv passVisitor) VisitService(Service) (v Visitor, err error) { return pv.v, nil } func (pv passVisitor) VisitMethod(Method) (v Visitor, err error) { return pv.v, nil } diff --git a/node_test.go b/node_test.go index 073ed3e..017bd63 100644 --- a/node_test.go +++ b/node_test.go @@ -68,6 +68,10 @@ func TestNilVisitor(t *testing.T) { assert.Nil(t, v) assert.NoError(t, err) + v, err = nv.VisitExtension(&ext{}) + assert.Nil(t, v) + assert.NoError(t, err) + v, err = nv.VisitOneOf(&oneof{}) assert.Nil(t, v) assert.NoError(t, err) @@ -111,6 +115,10 @@ func TestPassThroughVisitor(t *testing.T) { assert.Equal(t, nv, v) assert.NoError(t, err) + v, err = pv.VisitExtension(&ext{}) + assert.Equal(t, nv, v) + assert.NoError(t, err) + v, err = pv.VisitOneOf(&oneof{}) assert.Equal(t, nv, v) assert.NoError(t, err) @@ -120,7 +128,7 @@ type mockVisitor struct { v Visitor err error - pkg, file, message, enum, enumvalue, field, oneof, service, method int + pkg, file, message, enum, enumvalue, extension, field, oneof, service, method int } func (v *mockVisitor) VisitPackage(p Package) (w Visitor, err error) { @@ -153,6 +161,11 @@ func (v *mockVisitor) VisitField(f Field) (w Visitor, err error) { return v.v, v.err } +func (v *mockVisitor) VisitExtension(e Extension) (w Visitor, err error) { + v.extension++ + return v.v, v.err +} + func (v *mockVisitor) VisitOneOf(o OneOf) (w Visitor, err error) { v.oneof++ return v.v, v.err diff --git a/testdata/graph/extensions/everything.proto b/testdata/graph/extensions/everything.proto new file mode 100644 index 0000000..d817d24 --- /dev/null +++ b/testdata/graph/extensions/everything.proto @@ -0,0 +1,72 @@ + +syntax = "proto3"; +package extensions; +option go_package = "extensions"; +option (extensions.ext.owner) = "IDL Tools"; + +import "google/protobuf/descriptor.proto"; +import "google/protobuf/wrappers.proto"; +import "extensions/ext/api.proto"; +import "extensions/ext/data.proto"; + +message RootMessage { + option (extensions.ext.annotated) = true; + + message NestedMessage {} + + enum NestedEnum { + option (extensions.ext.ext) = {}; + + ZERO = 0 [(extensions.ext.numbers) = 1]; + ONE = 1; + TWO = 2; + } + + NestedMessage nested_msg = 1 [(extensions.ext.name) = "reflection"]; + NestedEnum nested_enum = 2; + + oneof union { + option (extensions.ext.float) = 5.67; + + bool boolean = 5; + string str = 6; + bytes data = 7; + } + + repeated NestedMessage rep_msg = 8; + repeated RootEnum rep_enum = 9; + repeated double rep_scalar = 10; + + map scalar_map = 11; + map recursive_map = 12; + map enum_map = 13; + + google.protobuf.StringValue wkt = 14; +} + +enum RootEnum { + ZERO = 0; + ONE = 1; + TWO = 2; +} + +message Request { + extend google.protobuf.FieldOptions { + string footer = 222333; + } +} + +message Response { +} + +service API { + option (extensions.ext.host) = "Alex Trebek"; + + rpc Do (Request) returns (Response) { + option (extensions.ext.header) = "X-Foo=BAR"; + } + + rpc Client (stream Request) returns (Response); + rpc Server (Request) returns (stream Response); + rpc BiDi (stream Request) returns (stream Response); +} diff --git a/testdata/graph/extensions/ext/api.proto b/testdata/graph/extensions/ext/api.proto new file mode 100644 index 0000000..10465c2 --- /dev/null +++ b/testdata/graph/extensions/ext/api.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; +package extensions.ext; +option go_package = "ext"; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.ServiceOptions { + string host = 111111; +} + +extend google.protobuf.MethodOptions { + string header = 222222; +} diff --git a/testdata/graph/extensions/ext/data.proto b/testdata/graph/extensions/ext/data.proto new file mode 100644 index 0000000..9f31f71 --- /dev/null +++ b/testdata/graph/extensions/ext/data.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; +package extensions.ext; +option go_package = "ext"; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.MessageOptions { + bool annotated = 123123; +} + +extend google.protobuf.FieldOptions { + string name = 456789; +} + +message EnumExtension {} + +extend google.protobuf.EnumOptions { + EnumExtension ext = 101112; +} + +extend google.protobuf.EnumValueOptions { + repeated int32 numbers = 131415; +} + +extend google.protobuf.OneofOptions { + double float = 161718; +} + +extend google.protobuf.FileOptions { + string owner = 192021; +}