Skip to content

Commit

Permalink
[IDL-324] Add extensions to PG* (lyft#52)
Browse files Browse the repository at this point in the history
* getting started

* add some tests

* addExtension method

* missing tests

* linting

* more stuff

* file/message/node tests and some AST stuff

* glide up

* tests for extension implementation

* glide up again

* feedback

* i broke everything :/

* y wont u hydrate :O

* slowly fix things

* fixed?

* testdata done

* tests

* assert len

* one more test
  • Loading branch information
alexkarim authored Mar 27, 2019
1 parent 2c275d2 commit 38e6c5c
Show file tree
Hide file tree
Showing 15 changed files with 415 additions and 21 deletions.
52 changes: 44 additions & 8 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package pgs

import (
"github.com/golang/protobuf/protoc-gen-go/descriptor"
plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/golang/protobuf/protoc-gen-go/plugin"
)

// AST encapsulates the entirety of the input CodeGeneratorRequest from protoc,
Expand All @@ -26,9 +26,10 @@ type AST interface {
type graph struct {
d Debugger

targets map[string]File
packages map[string]Package
entities map[string]Entity
targets map[string]File
packages map[string]Package
entities map[string]Entity
extensions []Extension
}

func (g *graph) Targets() map[string]File { return g.targets }
Expand All @@ -49,10 +50,11 @@ func ProcessDescriptors(debug Debugger, req *plugin_go.CodeGeneratorRequest) AST
// connected AST entity graph. An error is returned if the input is malformed.
func ProcessCodeGeneratorRequest(debug Debugger, req *plugin_go.CodeGeneratorRequest) AST {
g := &graph{
d: debug,
targets: make(map[string]File, len(req.GetFileToGenerate())),
packages: make(map[string]Package),
entities: make(map[string]Entity),
d: debug,
targets: make(map[string]File, len(req.GetFileToGenerate())),
packages: make(map[string]Package),
entities: make(map[string]Entity),
extensions: []Extension{},
}

for _, f := range req.GetFileToGenerate() {
Expand All @@ -64,6 +66,15 @@ func ProcessCodeGeneratorRequest(debug Debugger, req *plugin_go.CodeGeneratorReq
pkg.addFile(g.hydrateFile(pkg, f))
}

for _, e := range g.extensions {
e.addType(g.hydrateFieldType(e))
extendee := g.mustSeen(e.Descriptor().GetExtendee()).(Message)
e.setExtendee(extendee)
if extendee != nil {
extendee.addExtension(e)
}
}

return g
}

Expand Down Expand Up @@ -111,6 +122,13 @@ func (g *graph) hydrateFile(pkg Package, f *descriptor.FileDescriptorProto) File
fl.addEnum(g.hydrateEnum(fl, e))
}

exts := f.GetExtension()
fl.defExts = make([]Extension, 0, len(exts))
for _, ext := range exts {
e := g.hydrateExtension(fl, ext)
fl.addDefExtension(e)
}

msgs := f.GetMessageType()
fl.msgs = make([]Message, 0, len(f.GetMessageType()))
for _, msg := range msgs {
Expand Down Expand Up @@ -250,6 +268,13 @@ func (g *graph) hydrateMessage(p ParentEntity, md *descriptor.DescriptorProto) M
}
}

exts := md.GetExtension()
m.defExts = make([]Extension, 0, len(exts))
for _, ext := range md.GetExtension() {
e := g.hydrateExtension(m, ext)
m.addDefExtension(e)
}

return m
}

Expand All @@ -273,6 +298,17 @@ func (g *graph) hydrateOneOf(m Message, od *descriptor.OneofDescriptorProto) One
return o
}

func (g *graph) hydrateExtension(parent ParentEntity, fd *descriptor.FieldDescriptorProto) Extension {
ext := &ext{
parent: parent,
}
ext.desc = fd
g.add(ext)
g.extensions = append(g.extensions, ext)

return ext
}

func (g *graph) hydrateFieldType(fld Field) FieldType {
s := &scalarT{fld: fld}

Expand Down
22 changes: 22 additions & 0 deletions ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,25 @@ func TestGraph_Packageless(t *testing.T) {
})
}
}

func TestGraph_Extensions(t *testing.T) {
t.Parallel()

g := buildGraph(t, "extensions")
assert.NotNil(t, g)

ent, ok := g.Lookup("extensions/ext/data.proto")
assert.True(t, ok)
assert.NotNil(t, ent.(File).DefinedExtensions())
assert.Len(t, ent.(File).DefinedExtensions(), 6)

ent, ok = g.Lookup(".extensions.Request")
assert.True(t, ok)
assert.NotNil(t, ent.(Message).DefinedExtensions())
assert.Len(t, ent.(Message).DefinedExtensions(), 1)

ent, ok = g.Lookup(".google.protobuf.MessageOptions")
assert.True(t, ok)
assert.NotNil(t, ent.(Message).Extensions())
assert.Len(t, ent.(Message).Extensions(), 1)
}
4 changes: 4 additions & 0 deletions entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ type ParentEntity interface {
// AllEnums returns all top-level and nested enums from this entity.
AllEnums() []Enum

// DefinedExtensions returns all Extensions defined on this entity.
DefinedExtensions() []Extension

addMessage(m Message)
addMapEntry(m Message)
addEnum(e Enum)
addDefExtension(e Extension)
}
44 changes: 44 additions & 0 deletions extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,50 @@ import (
"github.com/golang/protobuf/proto"
)

// An Extension is a custom option annotation that can be applied to an Entity to provide additional
// semantic details and metadata about the Entity.
type Extension interface {
Field

// ParentEntity returns the ParentEntity where the Extension is defined
DefinedIn() ParentEntity

// Extendee returns the Message that the Extension is extending
Extendee() Message

setExtendee(m Message)
}

type ext struct {
field

parent ParentEntity
extendee Message
}

func (e *ext) FullyQualifiedName() string { return fullyQualifiedName(e.parent, e) }
func (e *ext) Syntax() Syntax { return e.parent.Syntax() }
func (e *ext) Package() Package { return e.parent.Package() }
func (e *ext) File() File { return e.parent.File() }
func (e *ext) BuildTarget() bool { return e.parent.BuildTarget() }
func (e *ext) DefinedIn() ParentEntity { return e.parent }
func (e *ext) Extendee() Message { return e.extendee }
func (e *ext) Message() Message { return nil }
func (e *ext) InOneOf() bool { return false }
func (e *ext) OneOf() OneOf { return nil }
func (e *ext) setMessage(m Message) {} // noop
func (e *ext) setOneOf(o OneOf) {} // noop
func (e *ext) setExtendee(m Message) { e.extendee = m }

func (e *ext) accept(v Visitor) (err error) {
if v == nil {
return
}

_, err = v.VisitExtension(e)
return
}

var extractor extExtractor

func init() { extractor = protoExtExtractor{} }
Expand Down
93 changes: 93 additions & 0 deletions extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,103 @@ import (
"errors"
"testing"

"github.com/golang/protobuf/protoc-gen-go/descriptor"

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
)

func TestExt_FullyQualifiedName(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{parent: msg}
e.desc = &descriptor.FieldDescriptorProto{Name: proto.String("foo")}
assert.Equal(t, msg.FullyQualifiedName()+".foo", e.FullyQualifiedName())
}

func TestExt_Syntax(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{parent: msg}
assert.Equal(t, msg.Syntax(), e.Syntax())
}

func TestExt_Package(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{parent: msg}
assert.Equal(t, msg.Package(), e.Package())
}

func TestExt_File(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{parent: msg}
assert.Equal(t, msg.File(), e.File())
}

func TestExt_BuildTarget(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{parent: msg}
assert.Equal(t, msg.BuildTarget(), e.BuildTarget())
}

func TestExt_ParentEntity(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{parent: msg}
assert.Equal(t, msg, e.DefinedIn())
}

func TestExt_Extendee(t *testing.T) {
t.Parallel()

msg := dummyMsg()
e := &ext{}
e.setExtendee(msg)
assert.Equal(t, msg, e.Extendee())
}

func TestExt_Message(t *testing.T) {
t.Parallel()

e := &ext{}
assert.Nil(t, e.Message())
}

func TestExt_InOneOf(t *testing.T) {
t.Parallel()

e := &ext{}
assert.False(t, e.InOneOf())
}

func TestExt_OneOf(t *testing.T) {
t.Parallel()

e := &ext{}
assert.Nil(t, e.OneOf())
}

func TestExt_Accept(t *testing.T) {
t.Parallel()

e := &ext{}

assert.NoError(t, e.accept(nil))

v := &mockVisitor{err: errors.New("")}
assert.Error(t, e.accept(v))
assert.Equal(t, 1, v.extension)
}

type mockExtractor struct {
has bool
get interface{}
Expand Down
9 changes: 9 additions & 0 deletions file.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type file struct {
desc *descriptor.FileDescriptorProto
pkg Package
enums []Enum
defExts []Extension
msgs []Message
srvs []Service
buildTarget bool
Expand Down Expand Up @@ -114,6 +115,10 @@ func (f *file) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, erro
return extension(f.desc.GetOptions(), desc, &ext)
}

func (f *file) DefinedExtensions() []Extension {
return f.defExts
}

func (f *file) accept(v Visitor) (err error) {
if v == nil {
return nil
Expand Down Expand Up @@ -144,6 +149,10 @@ func (f *file) accept(v Visitor) (err error) {
return
}

func (f *file) addDefExtension(ext Extension) {
f.defExts = append(f.defExts, ext)
}

func (f *file) setPackage(pkg Package) { f.pkg = pkg }

func (f *file) addEnum(e Enum) {
Expand Down
11 changes: 11 additions & 0 deletions file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,17 @@ func TestFile_Extension(t *testing.T) {
})
}

func TestFile_DefinedExtensions(t *testing.T) {
t.Parallel()

f := &file{}
assert.Empty(t, f.DefinedExtensions())

ext := &ext{}
f.addDefExtension(ext)
assert.Len(t, f.DefinedExtensions(), 1)
}

// needed to wrap since there is a File method
type mFile interface {
File
Expand Down
23 changes: 11 additions & 12 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 38e6c5c

Please sign in to comment.