forked from mirrors/gotosocial
870 lines
24 KiB
Go
870 lines
24 KiB
Go
|
// Copyright 2018 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
// Package protoregistry provides data structures to register and lookup
|
||
|
// protobuf descriptor types.
|
||
|
//
|
||
|
// The Files registry contains file descriptors and provides the ability
|
||
|
// to iterate over the files or lookup a specific descriptor within the files.
|
||
|
// Files only contains protobuf descriptors and has no understanding of Go
|
||
|
// type information that may be associated with each descriptor.
|
||
|
//
|
||
|
// The Types registry contains descriptor types for which there is a known
|
||
|
// Go type associated with that descriptor. It provides the ability to iterate
|
||
|
// over the registered types or lookup a type by name.
|
||
|
package protoregistry
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"os"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
"google.golang.org/protobuf/internal/encoding/messageset"
|
||
|
"google.golang.org/protobuf/internal/errors"
|
||
|
"google.golang.org/protobuf/internal/flags"
|
||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||
|
)
|
||
|
|
||
|
// conflictPolicy configures the policy for handling registration conflicts.
|
||
|
//
|
||
|
// It can be over-written at compile time with a linker-initialized variable:
|
||
|
// go build -ldflags "-X google.golang.org/protobuf/reflect/protoregistry.conflictPolicy=warn"
|
||
|
//
|
||
|
// It can be over-written at program execution with an environment variable:
|
||
|
// GOLANG_PROTOBUF_REGISTRATION_CONFLICT=warn ./main
|
||
|
//
|
||
|
// Neither of the above are covered by the compatibility promise and
|
||
|
// may be removed in a future release of this module.
|
||
|
var conflictPolicy = "panic" // "panic" | "warn" | "ignore"
|
||
|
|
||
|
// ignoreConflict reports whether to ignore a registration conflict
|
||
|
// given the descriptor being registered and the error.
|
||
|
// It is a variable so that the behavior is easily overridden in another file.
|
||
|
var ignoreConflict = func(d protoreflect.Descriptor, err error) bool {
|
||
|
const env = "GOLANG_PROTOBUF_REGISTRATION_CONFLICT"
|
||
|
const faq = "https://developers.google.com/protocol-buffers/docs/reference/go/faq#namespace-conflict"
|
||
|
policy := conflictPolicy
|
||
|
if v := os.Getenv(env); v != "" {
|
||
|
policy = v
|
||
|
}
|
||
|
switch policy {
|
||
|
case "panic":
|
||
|
panic(fmt.Sprintf("%v\nSee %v\n", err, faq))
|
||
|
case "warn":
|
||
|
fmt.Fprintf(os.Stderr, "WARNING: %v\nSee %v\n\n", err, faq)
|
||
|
return true
|
||
|
case "ignore":
|
||
|
return true
|
||
|
default:
|
||
|
panic("invalid " + env + " value: " + os.Getenv(env))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var globalMutex sync.RWMutex
|
||
|
|
||
|
// GlobalFiles is a global registry of file descriptors.
|
||
|
var GlobalFiles *Files = new(Files)
|
||
|
|
||
|
// GlobalTypes is the registry used by default for type lookups
|
||
|
// unless a local registry is provided by the user.
|
||
|
var GlobalTypes *Types = new(Types)
|
||
|
|
||
|
// NotFound is a sentinel error value to indicate that the type was not found.
|
||
|
//
|
||
|
// Since registry lookup can happen in the critical performance path, resolvers
|
||
|
// must return this exact error value, not an error wrapping it.
|
||
|
var NotFound = errors.New("not found")
|
||
|
|
||
|
// Files is a registry for looking up or iterating over files and the
|
||
|
// descriptors contained within them.
|
||
|
// The Find and Range methods are safe for concurrent use.
|
||
|
type Files struct {
|
||
|
// The map of descsByName contains:
|
||
|
// EnumDescriptor
|
||
|
// EnumValueDescriptor
|
||
|
// MessageDescriptor
|
||
|
// ExtensionDescriptor
|
||
|
// ServiceDescriptor
|
||
|
// *packageDescriptor
|
||
|
//
|
||
|
// Note that files are stored as a slice, since a package may contain
|
||
|
// multiple files. Only top-level declarations are registered.
|
||
|
// Note that enum values are in the top-level since that are in the same
|
||
|
// scope as the parent enum.
|
||
|
descsByName map[protoreflect.FullName]interface{}
|
||
|
filesByPath map[string]protoreflect.FileDescriptor
|
||
|
}
|
||
|
|
||
|
type packageDescriptor struct {
|
||
|
files []protoreflect.FileDescriptor
|
||
|
}
|
||
|
|
||
|
// RegisterFile registers the provided file descriptor.
|
||
|
//
|
||
|
// If any descriptor within the file conflicts with the descriptor of any
|
||
|
// previously registered file (e.g., two enums with the same full name),
|
||
|
// then the file is not registered and an error is returned.
|
||
|
//
|
||
|
// It is permitted for multiple files to have the same file path.
|
||
|
func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.Lock()
|
||
|
defer globalMutex.Unlock()
|
||
|
}
|
||
|
if r.descsByName == nil {
|
||
|
r.descsByName = map[protoreflect.FullName]interface{}{
|
||
|
"": &packageDescriptor{},
|
||
|
}
|
||
|
r.filesByPath = make(map[string]protoreflect.FileDescriptor)
|
||
|
}
|
||
|
path := file.Path()
|
||
|
if prev := r.filesByPath[path]; prev != nil {
|
||
|
r.checkGenProtoConflict(path)
|
||
|
err := errors.New("file %q is already registered", file.Path())
|
||
|
err = amendErrorWithCaller(err, prev, file)
|
||
|
if r == GlobalFiles && ignoreConflict(file, err) {
|
||
|
err = nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for name := file.Package(); name != ""; name = name.Parent() {
|
||
|
switch prev := r.descsByName[name]; prev.(type) {
|
||
|
case nil, *packageDescriptor:
|
||
|
default:
|
||
|
err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
|
||
|
err = amendErrorWithCaller(err, prev, file)
|
||
|
if r == GlobalFiles && ignoreConflict(file, err) {
|
||
|
err = nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
var err error
|
||
|
var hasConflict bool
|
||
|
rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
|
||
|
if prev := r.descsByName[d.FullName()]; prev != nil {
|
||
|
hasConflict = true
|
||
|
err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
|
||
|
err = amendErrorWithCaller(err, prev, file)
|
||
|
if r == GlobalFiles && ignoreConflict(d, err) {
|
||
|
err = nil
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
if hasConflict {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for name := file.Package(); name != ""; name = name.Parent() {
|
||
|
if r.descsByName[name] == nil {
|
||
|
r.descsByName[name] = &packageDescriptor{}
|
||
|
}
|
||
|
}
|
||
|
p := r.descsByName[file.Package()].(*packageDescriptor)
|
||
|
p.files = append(p.files, file)
|
||
|
rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
|
||
|
r.descsByName[d.FullName()] = d
|
||
|
})
|
||
|
r.filesByPath[path] = file
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Several well-known types were hosted in the google.golang.org/genproto module
|
||
|
// but were later moved to this module. To avoid a weak dependency on the
|
||
|
// genproto module (and its relatively large set of transitive dependencies),
|
||
|
// we rely on a registration conflict to determine whether the genproto version
|
||
|
// is too old (i.e., does not contain aliases to the new type declarations).
|
||
|
func (r *Files) checkGenProtoConflict(path string) {
|
||
|
if r != GlobalFiles {
|
||
|
return
|
||
|
}
|
||
|
var prevPath string
|
||
|
const prevModule = "google.golang.org/genproto"
|
||
|
const prevVersion = "cb27e3aa (May 26th, 2020)"
|
||
|
switch path {
|
||
|
case "google/protobuf/field_mask.proto":
|
||
|
prevPath = prevModule + "/protobuf/field_mask"
|
||
|
case "google/protobuf/api.proto":
|
||
|
prevPath = prevModule + "/protobuf/api"
|
||
|
case "google/protobuf/type.proto":
|
||
|
prevPath = prevModule + "/protobuf/ptype"
|
||
|
case "google/protobuf/source_context.proto":
|
||
|
prevPath = prevModule + "/protobuf/source_context"
|
||
|
default:
|
||
|
return
|
||
|
}
|
||
|
pkgName := strings.TrimSuffix(strings.TrimPrefix(path, "google/protobuf/"), ".proto")
|
||
|
pkgName = strings.Replace(pkgName, "_", "", -1) + "pb" // e.g., "field_mask" => "fieldmaskpb"
|
||
|
currPath := "google.golang.org/protobuf/types/known/" + pkgName
|
||
|
panic(fmt.Sprintf(""+
|
||
|
"duplicate registration of %q\n"+
|
||
|
"\n"+
|
||
|
"The generated definition for this file has moved:\n"+
|
||
|
"\tfrom: %q\n"+
|
||
|
"\tto: %q\n"+
|
||
|
"A dependency on the %q module must\n"+
|
||
|
"be at version %v or higher.\n"+
|
||
|
"\n"+
|
||
|
"Upgrade the dependency by running:\n"+
|
||
|
"\tgo get -u %v\n",
|
||
|
path, prevPath, currPath, prevModule, prevVersion, prevPath))
|
||
|
}
|
||
|
|
||
|
// FindDescriptorByName looks up a descriptor by the full name.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
prefix := name
|
||
|
suffix := nameSuffix("")
|
||
|
for prefix != "" {
|
||
|
if d, ok := r.descsByName[prefix]; ok {
|
||
|
switch d := d.(type) {
|
||
|
case protoreflect.EnumDescriptor:
|
||
|
if d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
case protoreflect.EnumValueDescriptor:
|
||
|
if d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
case protoreflect.MessageDescriptor:
|
||
|
if d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
if d := findDescriptorInMessage(d, suffix); d != nil && d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
case protoreflect.ExtensionDescriptor:
|
||
|
if d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
case protoreflect.ServiceDescriptor:
|
||
|
if d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
if d := d.Methods().ByName(suffix.Pop()); d != nil && d.FullName() == name {
|
||
|
return d, nil
|
||
|
}
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
prefix = prefix.Parent()
|
||
|
suffix = nameSuffix(name[len(prefix)+len("."):])
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
func findDescriptorInMessage(md protoreflect.MessageDescriptor, suffix nameSuffix) protoreflect.Descriptor {
|
||
|
name := suffix.Pop()
|
||
|
if suffix == "" {
|
||
|
if ed := md.Enums().ByName(name); ed != nil {
|
||
|
return ed
|
||
|
}
|
||
|
for i := md.Enums().Len() - 1; i >= 0; i-- {
|
||
|
if vd := md.Enums().Get(i).Values().ByName(name); vd != nil {
|
||
|
return vd
|
||
|
}
|
||
|
}
|
||
|
if xd := md.Extensions().ByName(name); xd != nil {
|
||
|
return xd
|
||
|
}
|
||
|
if fd := md.Fields().ByName(name); fd != nil {
|
||
|
return fd
|
||
|
}
|
||
|
if od := md.Oneofs().ByName(name); od != nil {
|
||
|
return od
|
||
|
}
|
||
|
}
|
||
|
if md := md.Messages().ByName(name); md != nil {
|
||
|
if suffix == "" {
|
||
|
return md
|
||
|
}
|
||
|
return findDescriptorInMessage(md, suffix)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
type nameSuffix string
|
||
|
|
||
|
func (s *nameSuffix) Pop() (name protoreflect.Name) {
|
||
|
if i := strings.IndexByte(string(*s), '.'); i >= 0 {
|
||
|
name, *s = protoreflect.Name((*s)[:i]), (*s)[i+1:]
|
||
|
} else {
|
||
|
name, *s = protoreflect.Name((*s)), ""
|
||
|
}
|
||
|
return name
|
||
|
}
|
||
|
|
||
|
// FindFileByPath looks up a file by the path.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
if fd, ok := r.filesByPath[path]; ok {
|
||
|
return fd, nil
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
// NumFiles reports the number of registered files.
|
||
|
func (r *Files) NumFiles() int {
|
||
|
if r == nil {
|
||
|
return 0
|
||
|
}
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
return len(r.filesByPath)
|
||
|
}
|
||
|
|
||
|
// RangeFiles iterates over all registered files while f returns true.
|
||
|
// The iteration order is undefined.
|
||
|
func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
for _, file := range r.filesByPath {
|
||
|
if !f(file) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NumFilesByPackage reports the number of registered files in a proto package.
|
||
|
func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
|
||
|
if r == nil {
|
||
|
return 0
|
||
|
}
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
p, ok := r.descsByName[name].(*packageDescriptor)
|
||
|
if !ok {
|
||
|
return 0
|
||
|
}
|
||
|
return len(p.files)
|
||
|
}
|
||
|
|
||
|
// RangeFilesByPackage iterates over all registered files in a given proto package
|
||
|
// while f returns true. The iteration order is undefined.
|
||
|
func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
if r == GlobalFiles {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
p, ok := r.descsByName[name].(*packageDescriptor)
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
for _, file := range p.files {
|
||
|
if !f(file) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// rangeTopLevelDescriptors iterates over all top-level descriptors in a file
|
||
|
// which will be directly entered into the registry.
|
||
|
func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflect.Descriptor)) {
|
||
|
eds := fd.Enums()
|
||
|
for i := eds.Len() - 1; i >= 0; i-- {
|
||
|
f(eds.Get(i))
|
||
|
vds := eds.Get(i).Values()
|
||
|
for i := vds.Len() - 1; i >= 0; i-- {
|
||
|
f(vds.Get(i))
|
||
|
}
|
||
|
}
|
||
|
mds := fd.Messages()
|
||
|
for i := mds.Len() - 1; i >= 0; i-- {
|
||
|
f(mds.Get(i))
|
||
|
}
|
||
|
xds := fd.Extensions()
|
||
|
for i := xds.Len() - 1; i >= 0; i-- {
|
||
|
f(xds.Get(i))
|
||
|
}
|
||
|
sds := fd.Services()
|
||
|
for i := sds.Len() - 1; i >= 0; i-- {
|
||
|
f(sds.Get(i))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// MessageTypeResolver is an interface for looking up messages.
|
||
|
//
|
||
|
// A compliant implementation must deterministically return the same type
|
||
|
// if no error is encountered.
|
||
|
//
|
||
|
// The Types type implements this interface.
|
||
|
type MessageTypeResolver interface {
|
||
|
// FindMessageByName looks up a message by its full name.
|
||
|
// E.g., "google.protobuf.Any"
|
||
|
//
|
||
|
// This return (nil, NotFound) if not found.
|
||
|
FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error)
|
||
|
|
||
|
// FindMessageByURL looks up a message by a URL identifier.
|
||
|
// See documentation on google.protobuf.Any.type_url for the URL format.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
FindMessageByURL(url string) (protoreflect.MessageType, error)
|
||
|
}
|
||
|
|
||
|
// ExtensionTypeResolver is an interface for looking up extensions.
|
||
|
//
|
||
|
// A compliant implementation must deterministically return the same type
|
||
|
// if no error is encountered.
|
||
|
//
|
||
|
// The Types type implements this interface.
|
||
|
type ExtensionTypeResolver interface {
|
||
|
// FindExtensionByName looks up a extension field by the field's full name.
|
||
|
// Note that this is the full name of the field as determined by
|
||
|
// where the extension is declared and is unrelated to the full name of the
|
||
|
// message being extended.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
|
||
|
|
||
|
// FindExtensionByNumber looks up a extension field by the field number
|
||
|
// within some parent message, identified by full name.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
_ MessageTypeResolver = (*Types)(nil)
|
||
|
_ ExtensionTypeResolver = (*Types)(nil)
|
||
|
)
|
||
|
|
||
|
// Types is a registry for looking up or iterating over descriptor types.
|
||
|
// The Find and Range methods are safe for concurrent use.
|
||
|
type Types struct {
|
||
|
typesByName typesByName
|
||
|
extensionsByMessage extensionsByMessage
|
||
|
|
||
|
numEnums int
|
||
|
numMessages int
|
||
|
numExtensions int
|
||
|
}
|
||
|
|
||
|
type (
|
||
|
typesByName map[protoreflect.FullName]interface{}
|
||
|
extensionsByMessage map[protoreflect.FullName]extensionsByNumber
|
||
|
extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType
|
||
|
)
|
||
|
|
||
|
// RegisterMessage registers the provided message type.
|
||
|
//
|
||
|
// If a naming conflict occurs, the type is not registered and an error is returned.
|
||
|
func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
|
||
|
// Under rare circumstances getting the descriptor might recursively
|
||
|
// examine the registry, so fetch it before locking.
|
||
|
md := mt.Descriptor()
|
||
|
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.Lock()
|
||
|
defer globalMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
if err := r.register("message", md, mt); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
r.numMessages++
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// RegisterEnum registers the provided enum type.
|
||
|
//
|
||
|
// If a naming conflict occurs, the type is not registered and an error is returned.
|
||
|
func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
|
||
|
// Under rare circumstances getting the descriptor might recursively
|
||
|
// examine the registry, so fetch it before locking.
|
||
|
ed := et.Descriptor()
|
||
|
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.Lock()
|
||
|
defer globalMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
if err := r.register("enum", ed, et); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
r.numEnums++
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// RegisterExtension registers the provided extension type.
|
||
|
//
|
||
|
// If a naming conflict occurs, the type is not registered and an error is returned.
|
||
|
func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
|
||
|
// Under rare circumstances getting the descriptor might recursively
|
||
|
// examine the registry, so fetch it before locking.
|
||
|
//
|
||
|
// A known case where this can happen: Fetching the TypeDescriptor for a
|
||
|
// legacy ExtensionDesc can consult the global registry.
|
||
|
xd := xt.TypeDescriptor()
|
||
|
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.Lock()
|
||
|
defer globalMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
field := xd.Number()
|
||
|
message := xd.ContainingMessage().FullName()
|
||
|
if prev := r.extensionsByMessage[message][field]; prev != nil {
|
||
|
err := errors.New("extension number %d is already registered on message %v", field, message)
|
||
|
err = amendErrorWithCaller(err, prev, xt)
|
||
|
if !(r == GlobalTypes && ignoreConflict(xd, err)) {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err := r.register("extension", xd, xt); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if r.extensionsByMessage == nil {
|
||
|
r.extensionsByMessage = make(extensionsByMessage)
|
||
|
}
|
||
|
if r.extensionsByMessage[message] == nil {
|
||
|
r.extensionsByMessage[message] = make(extensionsByNumber)
|
||
|
}
|
||
|
r.extensionsByMessage[message][field] = xt
|
||
|
r.numExtensions++
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
|
||
|
name := desc.FullName()
|
||
|
prev := r.typesByName[name]
|
||
|
if prev != nil {
|
||
|
err := errors.New("%v %v is already registered", kind, name)
|
||
|
err = amendErrorWithCaller(err, prev, typ)
|
||
|
if !(r == GlobalTypes && ignoreConflict(desc, err)) {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
if r.typesByName == nil {
|
||
|
r.typesByName = make(typesByName)
|
||
|
}
|
||
|
r.typesByName[name] = typ
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// FindEnumByName looks up an enum by its full name.
|
||
|
// E.g., "google.protobuf.Field.Kind".
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) {
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
if v := r.typesByName[enum]; v != nil {
|
||
|
if et, _ := v.(protoreflect.EnumType); et != nil {
|
||
|
return et, nil
|
||
|
}
|
||
|
return nil, errors.New("found wrong type: got %v, want enum", typeName(v))
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
// FindMessageByName looks up a message by its full name,
|
||
|
// e.g. "google.protobuf.Any".
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Types) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
if v := r.typesByName[message]; v != nil {
|
||
|
if mt, _ := v.(protoreflect.MessageType); mt != nil {
|
||
|
return mt, nil
|
||
|
}
|
||
|
return nil, errors.New("found wrong type: got %v, want message", typeName(v))
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
// FindMessageByURL looks up a message by a URL identifier.
|
||
|
// See documentation on google.protobuf.Any.type_url for the URL format.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
|
||
|
// This function is similar to FindMessageByName but
|
||
|
// truncates anything before and including '/' in the URL.
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
message := protoreflect.FullName(url)
|
||
|
if i := strings.LastIndexByte(url, '/'); i >= 0 {
|
||
|
message = message[i+len("/"):]
|
||
|
}
|
||
|
|
||
|
if v := r.typesByName[message]; v != nil {
|
||
|
if mt, _ := v.(protoreflect.MessageType); mt != nil {
|
||
|
return mt, nil
|
||
|
}
|
||
|
return nil, errors.New("found wrong type: got %v, want message", typeName(v))
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
// FindExtensionByName looks up a extension field by the field's full name.
|
||
|
// Note that this is the full name of the field as determined by
|
||
|
// where the extension is declared and is unrelated to the full name of the
|
||
|
// message being extended.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
if v := r.typesByName[field]; v != nil {
|
||
|
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
|
||
|
return xt, nil
|
||
|
}
|
||
|
|
||
|
// MessageSet extensions are special in that the name of the extension
|
||
|
// is the name of the message type used to extend the MessageSet.
|
||
|
// This naming scheme is used by text and JSON serialization.
|
||
|
//
|
||
|
// This feature is protected by the ProtoLegacy flag since MessageSets
|
||
|
// are a proto1 feature that is long deprecated.
|
||
|
if flags.ProtoLegacy {
|
||
|
if _, ok := v.(protoreflect.MessageType); ok {
|
||
|
field := field.Append(messageset.ExtensionName)
|
||
|
if v := r.typesByName[field]; v != nil {
|
||
|
if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
|
||
|
if messageset.IsMessageSetExtension(xt.TypeDescriptor()) {
|
||
|
return xt, nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
// FindExtensionByNumber looks up a extension field by the field number
|
||
|
// within some parent message, identified by full name.
|
||
|
//
|
||
|
// This returns (nil, NotFound) if not found.
|
||
|
func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
|
||
|
if r == nil {
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
if xt, ok := r.extensionsByMessage[message][field]; ok {
|
||
|
return xt, nil
|
||
|
}
|
||
|
return nil, NotFound
|
||
|
}
|
||
|
|
||
|
// NumEnums reports the number of registered enums.
|
||
|
func (r *Types) NumEnums() int {
|
||
|
if r == nil {
|
||
|
return 0
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
return r.numEnums
|
||
|
}
|
||
|
|
||
|
// RangeEnums iterates over all registered enums while f returns true.
|
||
|
// Iteration order is undefined.
|
||
|
func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
for _, typ := range r.typesByName {
|
||
|
if et, ok := typ.(protoreflect.EnumType); ok {
|
||
|
if !f(et) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NumMessages reports the number of registered messages.
|
||
|
func (r *Types) NumMessages() int {
|
||
|
if r == nil {
|
||
|
return 0
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
return r.numMessages
|
||
|
}
|
||
|
|
||
|
// RangeMessages iterates over all registered messages while f returns true.
|
||
|
// Iteration order is undefined.
|
||
|
func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
for _, typ := range r.typesByName {
|
||
|
if mt, ok := typ.(protoreflect.MessageType); ok {
|
||
|
if !f(mt) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NumExtensions reports the number of registered extensions.
|
||
|
func (r *Types) NumExtensions() int {
|
||
|
if r == nil {
|
||
|
return 0
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
return r.numExtensions
|
||
|
}
|
||
|
|
||
|
// RangeExtensions iterates over all registered extensions while f returns true.
|
||
|
// Iteration order is undefined.
|
||
|
func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
for _, typ := range r.typesByName {
|
||
|
if xt, ok := typ.(protoreflect.ExtensionType); ok {
|
||
|
if !f(xt) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NumExtensionsByMessage reports the number of registered extensions for
|
||
|
// a given message type.
|
||
|
func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
|
||
|
if r == nil {
|
||
|
return 0
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
return len(r.extensionsByMessage[message])
|
||
|
}
|
||
|
|
||
|
// RangeExtensionsByMessage iterates over all registered extensions filtered
|
||
|
// by a given message type while f returns true. Iteration order is undefined.
|
||
|
func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
if r == GlobalTypes {
|
||
|
globalMutex.RLock()
|
||
|
defer globalMutex.RUnlock()
|
||
|
}
|
||
|
for _, xt := range r.extensionsByMessage[message] {
|
||
|
if !f(xt) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func typeName(t interface{}) string {
|
||
|
switch t.(type) {
|
||
|
case protoreflect.EnumType:
|
||
|
return "enum"
|
||
|
case protoreflect.MessageType:
|
||
|
return "message"
|
||
|
case protoreflect.ExtensionType:
|
||
|
return "extension"
|
||
|
default:
|
||
|
return fmt.Sprintf("%T", t)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func amendErrorWithCaller(err error, prev, curr interface{}) error {
|
||
|
prevPkg := goPackage(prev)
|
||
|
currPkg := goPackage(curr)
|
||
|
if prevPkg == "" || currPkg == "" || prevPkg == currPkg {
|
||
|
return err
|
||
|
}
|
||
|
return errors.New("%s\n\tpreviously from: %q\n\tcurrently from: %q", err, prevPkg, currPkg)
|
||
|
}
|
||
|
|
||
|
func goPackage(v interface{}) string {
|
||
|
switch d := v.(type) {
|
||
|
case protoreflect.EnumType:
|
||
|
v = d.Descriptor()
|
||
|
case protoreflect.MessageType:
|
||
|
v = d.Descriptor()
|
||
|
case protoreflect.ExtensionType:
|
||
|
v = d.TypeDescriptor()
|
||
|
}
|
||
|
if d, ok := v.(protoreflect.Descriptor); ok {
|
||
|
v = d.ParentFile()
|
||
|
}
|
||
|
if d, ok := v.(interface{ GoPackagePath() string }); ok {
|
||
|
return d.GoPackagePath()
|
||
|
}
|
||
|
return ""
|
||
|
}
|