woodpecker/vendor/github.com/quasilyte/go-ruleguard/ruleguard/irconv/irconv.go
2021-11-16 21:07:53 +01:00

630 lines
16 KiB
Go

package irconv
import (
"fmt"
"go/ast"
"go/constant"
"go/token"
"go/types"
"path"
"strconv"
"strings"
"github.com/quasilyte/go-ruleguard/ruleguard/goutil"
"github.com/quasilyte/go-ruleguard/ruleguard/ir"
"golang.org/x/tools/go/ast/astutil"
)
// Context carries everything ConvertFile needs besides the AST itself.
type Context struct {
// Pkg supplies the package path recorded in the resulting ir.File.
Pkg *types.Package
// Types is the type information consulted for object and constant-value lookups.
Types *types.Info
// Fset maps token positions to file names, line numbers and byte offsets.
Fset *token.FileSet
// Src holds source bytes (presumably of the file passed to ConvertFile);
// the text of custom declarations is sliced out of it by offset.
Src []byte
}
// ConvertFile translates a parsed rules file into its IR representation.
//
// The converter signals failures by panicking with a convError. The deferred
// handler turns such a panic into the returned error (leaving result nil)
// and re-raises every panic value of any other type.
func ConvertFile(ctx *Context, f *ast.File) (result *ir.File, err error) {
	defer func() {
		if err != nil {
			return
		}
		switch rv := recover().(type) {
		case nil:
			// Normal return, nothing was thrown.
		case convError:
			err = rv.err
		default:
			panic(rv) // not our panic
		}
	}()

	conv := &converter{
		types: ctx.Types,
		pkg:   ctx.Pkg,
		fset:  ctx.Fset,
		src:   ctx.Src,
	}
	return conv.ConvertFile(f), nil
}
// convError wraps a conversion error so that it can travel through panic.
// The package-level ConvertFile recovers values of this type and returns
// their err; any other panic value is re-raised.
type convError struct {
err error
}
// converter holds the state of a single file conversion.
type converter struct {
// types, pkg, fset and src are copied from the Context given to ConvertFile.
types *types.Info
pkg *types.Package
fset *token.FileSet
src []byte
// group is the rule group currently being filled in; convertRuleGroup sets it.
group *ir.RuleGroup
dslPkgname string // The local name of the "ruleguard/dsl" package (usually it's just "dsl")
}
// errorf builds a convError whose message is prefixed with the file name and
// line of node n. It only constructs the value; callers panic with it.
func (conv *converter) errorf(n ast.Node, format string, args ...interface{}) convError {
	pos := conv.fset.Position(n.Pos())
	text := fmt.Sprintf(format, args...)
	return convError{
		err: fmt.Errorf("%s:%d: %s", pos.Filename, pos.Line, text),
	}
}
// ConvertFile converts every declaration of f into the returned ir.File.
//
// It first determines the local name of the ruleguard dsl package from the
// import list (defaulting to "dsl"), then dispatches each top-level
// declaration:
//   - non-import general declarations are kept verbatim as custom decls;
//   - a function named "init" is processed by convertInitFunc;
//   - matcher functions become rule groups;
//   - every other function is kept verbatim as a custom decl.
//
// Failures are raised by panicking with a convError.
func (conv *converter) ConvertFile(f *ast.File) *ir.File {
	result := &ir.File{
		PkgPath: conv.pkg.Path(),
	}

	conv.dslPkgname = "dsl"
	for _, imp := range f.Imports {
		importPath, err := strconv.Unquote(imp.Path.Value)
		if err != nil {
			panic(conv.errorf(imp, "unquote %s import path: %s", imp.Path.Value, err))
		}
		if importPath == "github.com/quasilyte/go-ruleguard/dsl" && imp.Name != nil {
			conv.dslPkgname = imp.Name.Name
		}
	}

	for _, decl := range f.Decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			if decl.Tok != token.IMPORT {
				conv.addCustomDecl(result, decl)
			}
		case *ast.FuncDecl:
			switch {
			case decl.Name.String() == "init":
				conv.convertInitFunc(result, decl)
			case conv.isMatcherFunc(decl):
				result.RuleGroups = append(result.RuleGroups, *conv.convertRuleGroup(decl))
			default:
				conv.addCustomDecl(result, decl)
			}
		default:
			// E.g. *ast.BadDecl. Report it through the regular error channel
			// instead of failing a type assertion with a runtime panic.
			panic(conv.errorf(decl, "unsupported declaration: %T", decl))
		}
	}

	return result
}
// convertInitFunc processes the body of an init function.
//
// Every statement has to be a call of the form <dsl>.ImportRules(prefix, pkg.Bundle),
// where <dsl> is the local name of the dsl package. Each such call is
// recorded in dst.BundleImports; anything else is reported as an error.
func (conv *converter) convertInitFunc(dst *ir.File, decl *ast.FuncDecl) {
	for _, stmt := range decl.Body.List {
		exprStmt, isExpr := stmt.(*ast.ExprStmt)
		if !isExpr {
			panic(conv.errorf(stmt, "unsupported statement"))
		}
		call, isCall := exprStmt.X.(*ast.CallExpr)
		if !isCall {
			panic(conv.errorf(stmt, "unsupported expr"))
		}
		callee, isSelector := call.Fun.(*ast.SelectorExpr)
		if !isSelector {
			panic(conv.errorf(stmt, "unsupported call"))
		}
		if qualifier, isIdent := callee.X.(*ast.Ident); !isIdent || qualifier.Name != conv.dslPkgname {
			panic(conv.errorf(stmt, "unsupported call"))
		}
		if method := callee.Sel.Name; method != "ImportRules" {
			panic(conv.errorf(stmt, "unsupported %s call", method))
		}

		prefix := conv.parseStringArg(call.Args[0])
		bundleArg := call.Args[1]
		bundleSelector, isBundleSelector := bundleArg.(*ast.SelectorExpr)
		if !isBundleSelector {
			panic(conv.errorf(bundleArg, "expected a `pkgname.Bundle` argument"))
		}
		// The bundle's package path comes from the object its selector resolves to.
		bundleObj := conv.types.ObjectOf(bundleSelector.Sel)
		dst.BundleImports = append(dst.BundleImports, ir.BundleImport{
			Prefix:  prefix,
			PkgPath: bundleObj.Pkg().Path(),
			Line:    conv.fset.Position(exprStmt.Pos()).Line,
		})
	}
}
// addCustomDecl appends the source text of decl, cut out of conv.src by the
// byte offsets of its start and end positions, to dst.CustomDecls.
func (conv *converter) addCustomDecl(dst *ir.File, decl ast.Decl) {
	from := conv.fset.Position(decl.Pos()).Offset
	to := conv.fset.Position(decl.End()).Offset
	dst.CustomDecls = append(dst.CustomDecls, string(conv.src[from:to]))
}
// isMatcherFunc reports whether f has the shape of a rule group function:
// no results and exactly one parameter whose type prints as the dsl Matcher.
func (conv *converter) isMatcherFunc(f *ast.FuncDecl) bool {
	sig := conv.types.ObjectOf(f.Name).Type().(*types.Signature)
	if sig.Results().Len() != 0 {
		return false
	}
	params := sig.Params()
	if params.Len() != 1 {
		return false
	}
	return params.At(0).Type().String() == "github.com/quasilyte/go-ruleguard/dsl.Matcher"
}
// convertRuleGroup converts a matcher function declaration into a rule group.
//
// The group becomes conv.group for the duration of the conversion, so the
// helpers called from here (doc comments, imports, rules) fill it in place.
// Inside the body, declaration statements are skipped; every other statement
// must be a call expression. Calls to the matcher's Import method are only
// allowed before the first rule; all remaining calls are converted as rules.
//
// Failures are raised by panicking with a convError.
func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup {
	// isMatcherFunc looks only at the signature's types, so a declaration with
	// an unnamed parameter (`func f(dsl.Matcher)`) or without a body can get
	// here. Report both as regular errors rather than crashing on an
	// out-of-range index or a nil dereference.
	matcherParam := decl.Type.Params.List[0]
	if len(matcherParam.Names) == 0 {
		panic(conv.errorf(matcherParam, "%s: the matcher parameter should be named", decl.Name.String()))
	}
	if decl.Body == nil {
		panic(conv.errorf(decl, "%s: missing function body", decl.Name.String()))
	}

	result := &ir.RuleGroup{
		Line:        conv.fset.Position(decl.Name.Pos()).Line,
		Name:        decl.Name.String(),
		MatcherName: matcherParam.Names[0].String(),
	}
	conv.group = result

	if decl.Doc != nil {
		conv.convertDocComments(decl.Doc)
	}

	seenRules := false
	for _, stmt := range decl.Body.List {
		if _, ok := stmt.(*ast.DeclStmt); ok {
			continue
		}

		var call *ast.CallExpr
		if stmtExpr, ok := stmt.(*ast.ExprStmt); ok {
			call, _ = stmtExpr.X.(*ast.CallExpr)
		}
		if call == nil {
			panic(conv.errorf(stmt, "expected a %s method call, found %s", result.MatcherName, goutil.SprintNode(conv.fset, stmt)))
		}

		switch conv.matcherMethodName(call) {
		case "Import":
			if seenRules {
				panic(conv.errorf(call, "Import() should be used before any rules definitions"))
			}
			conv.doMatcherImport(call)
		default:
			seenRules = true
			conv.convertRuleExpr(call)
		}
	}

	return result
}
// doMatcherImport records the package named by the first argument of an
// Import() call in the current group. The import's local name is the last
// element of the package path.
func (conv *converter) doMatcherImport(call *ast.CallExpr) {
	importPath := conv.parseStringArg(call.Args[0])
	imp := ir.PackageImport{
		Path: importPath,
		Name: path.Base(importPath),
	}
	conv.group.Imports = append(conv.group.Imports, imp)
}
// matcherMethodName returns the method name when call has the form
// <matcher>.<method>(...), with <matcher> being the current group's matcher
// identifier. For any other call shape it returns "".
func (conv *converter) matcherMethodName(call *ast.CallExpr) string {
	if sel, isSelector := call.Fun.(*ast.SelectorExpr); isSelector {
		if recv, isIdent := sel.X.(*ast.Ident); isIdent && recv.Name == conv.group.MatcherName {
			return sel.Sel.Name
		}
	}
	return ""
}
// convertDocComments scans a function's doc comment group for "//doc:<pragma>"
// lines and stores their values in the current group. Comment lines without
// the "//doc:" prefix are ignored; a "//doc:" line whose text does not start
// with one of the known pragma names is reported as an error.
func (conv *converter) convertDocComments(comment *ast.CommentGroup) {
	const marker = "//doc:"
	pragmaNames := [...]string{"tags", "summary", "before", "after", "note"}

	for _, c := range comment.List {
		if !strings.HasPrefix(c.Text, marker) {
			continue
		}
		body := c.Text[len(marker):]

		name := ""
		for _, candidate := range pragmaNames {
			if strings.HasPrefix(body, candidate) {
				name = candidate
				break
			}
		}
		if name == "" {
			panic(conv.errorf(c, "unrecognized 'doc' pragma in comment"))
		}

		// Whatever follows the pragma name, trimmed, is the pragma's value.
		value := strings.TrimSpace(body[len(name):])
		switch name {
		case "tags":
			conv.group.DocTags = strings.Fields(value)
		case "summary":
			conv.group.DocSummary = value
		case "before":
			conv.group.DocBefore = value
		case "after":
			conv.group.DocAfter = value
		case "note":
			conv.group.DocNote = value
		default:
			panic("unhandled 'doc' pragma: " + name) // Should never happen
		}
	}
}
// convertRuleExpr converts one rule, written as a chain of method calls
// (Match/MatchComment, Where, Suggest, Report, At), into an ir.Rule and
// appends it to the current group's Rules.
//
// call is the outermost call of the chain. Each method may occur at most once
// and Match() cannot be combined with MatchComment(). A rule needs one of
// Match()/MatchComment() and at least one of Report()/Suggest().
//
// Failures are raised by panicking with a convError.
func (conv *converter) convertRuleExpr(call *ast.CallExpr) {
origCall := call
// Argument lists of the methods seen in the chain; nil means "not seen yet".
var (
matchArgs *[]ast.Expr
matchCommentArgs *[]ast.Expr
whereArgs *[]ast.Expr
suggestArgs *[]ast.Expr
reportArgs *[]ast.Expr
atArgs *[]ast.Expr
)
// Walk the chain from the outermost call towards its receiver: call.Fun is
// the selector naming the method, and the selector's X is the previous call.
for {
chain, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
break
}
switch chain.Sel.Name {
case "Match":
if matchArgs != nil {
panic(conv.errorf(chain.Sel, "Match() can't be repeated"))
}
if matchCommentArgs != nil {
panic(conv.errorf(chain.Sel, "Match() and MatchComment() can't be combined"))
}
matchArgs = &call.Args
case "MatchComment":
if matchCommentArgs != nil {
panic(conv.errorf(chain.Sel, "MatchComment() can't be repeated"))
}
if matchArgs != nil {
panic(conv.errorf(chain.Sel, "Match() and MatchComment() can't be combined"))
}
matchCommentArgs = &call.Args
case "Where":
if whereArgs != nil {
panic(conv.errorf(chain.Sel, "Where() can't be repeated"))
}
whereArgs = &call.Args
case "Suggest":
if suggestArgs != nil {
panic(conv.errorf(chain.Sel, "Suggest() can't be repeated"))
}
suggestArgs = &call.Args
case "Report":
if reportArgs != nil {
panic(conv.errorf(chain.Sel, "Report() can't be repeated"))
}
reportArgs = &call.Args
case "At":
if atArgs != nil {
panic(conv.errorf(chain.Sel, "At() can't be repeated"))
}
atArgs = &call.Args
default:
panic(conv.errorf(chain.Sel, "unexpected %s method", chain.Sel.Name))
}
// Stop once the receiver is no longer a call.
call, ok = chain.X.(*ast.CallExpr)
if !ok {
break
}
}
// AST patterns for Match() or regexp patterns for MatchComment().
var alternatives []string
var alternativeLines []int
if matchArgs == nil && matchCommentArgs == nil {
panic(conv.errorf(origCall, "missing Match() or MatchComment() call"))
}
if matchArgs != nil {
for _, arg := range *matchArgs {
alternatives = append(alternatives, conv.parseStringArg(arg))
alternativeLines = append(alternativeLines, conv.fset.Position(arg.Pos()).Line)
}
} else {
for _, arg := range *matchCommentArgs {
alternatives = append(alternatives, conv.parseStringArg(arg))
alternativeLines = append(alternativeLines, conv.fset.Position(arg.Pos()).Line)
}
}
rule := ir.Rule{Line: conv.fset.Position(origCall.Pos()).Line}
if atArgs != nil {
// At() takes an index expression; its string index becomes the location variable.
index, ok := (*atArgs)[0].(*ast.IndexExpr)
if !ok {
panic(conv.errorf((*atArgs)[0], "expected %s[`varname`] expression", conv.group.MatcherName))
}
rule.LocationVar = conv.parseStringArg(index.Index)
}
if whereArgs != nil {
rule.WhereExpr = conv.convertFilterExpr((*whereArgs)[0])
}
if suggestArgs != nil {
rule.SuggestTemplate = conv.parseStringArg((*suggestArgs)[0])
}
if suggestArgs == nil && reportArgs == nil {
panic(conv.errorf(origCall, "missing Report() or Suggest() call"))
}
if reportArgs == nil {
// Only Suggest() was given: derive the report message from the suggestion.
rule.ReportTemplate = "suggestion: " + rule.SuggestTemplate
} else {
rule.ReportTemplate = conv.parseStringArg((*reportArgs)[0])
}
// Store each alternative, with its line, as a syntax or a comment pattern.
for i, alt := range alternatives {
pat := ir.PatternString{
Line: alternativeLines[i],
Value: alt,
}
if matchArgs != nil {
rule.SyntaxPatterns = append(rule.SyntaxPatterns, pat)
} else {
rule.CommentPatterns = append(rule.CommentPatterns, pat)
}
}
conv.group.Rules = append(conv.group.Rules, rule)
}
// convertFilterExpr converts the filter expression e and annotates the result
// with its printed source and line. A conversion that yields an invalid
// FilterExpr is reported as an "unsupported expr" error.
func (conv *converter) convertFilterExpr(e ast.Expr) ir.FilterExpr {
	filter := conv.convertFilterExprImpl(e)
	filter.Src = goutil.SprintNode(conv.fset, e)
	filter.Line = conv.fset.Position(e.Pos()).Line
	if filter.IsValid() {
		return filter
	}
	panic(conv.errorf(e, "unsupported expr: %s (%T)", filter.Src, e))
}
// convertFilterExprImpl maps a filter expression to an ir.FilterExpr without
// setting its Src and Line (convertFilterExpr does that).
//
// Expressions it does not recognize produce the zero ir.FilterExpr, which
// convertFilterExpr checks with IsValid and reports as unsupported. An unknown
// binary operator, a non-identifier Filter() argument and a Node.Parent.Is on
// a variable other than $$ are reported right here by panicking with a convError.
func (conv *converter) convertFilterExprImpl(e ast.Expr) ir.FilterExpr {
// Constant expressions are handled first: string constants and int constants
// that fit into int64 become literal operands regardless of their syntax.
if cv := conv.types.Types[e].Value; cv != nil {
switch cv.Kind() {
case constant.String:
v := constant.StringVal(cv)
return ir.FilterExpr{Op: ir.FilterStringOp, Value: v}
case constant.Int:
v, ok := constant.Int64Val(cv)
if ok {
return ir.FilterExpr{Op: ir.FilterIntOp, Value: v}
}
}
}
// convertExprList converts every expression of list; an empty list yields nil.
convertExprList := func(list []ast.Expr) []ir.FilterExpr {
if len(list) == 0 {
return nil
}
result := make([]ir.FilterExpr, len(list))
for i, e := range list {
result[i] = conv.convertFilterExpr(e)
}
return result
}
switch e := e.(type) {
case *ast.ParenExpr:
return conv.convertFilterExpr(e.X)
case *ast.UnaryExpr:
// Only logical negation is mapped; other unary operators fall through
// to the zero result at the end of the function.
x := conv.convertFilterExpr(e.X)
args := []ir.FilterExpr{x}
switch e.Op {
case token.NOT:
return ir.FilterExpr{Op: ir.FilterNotOp, Args: args}
}
case *ast.BinaryExpr:
x := conv.convertFilterExpr(e.X)
y := conv.convertFilterExpr(e.Y)
args := []ir.FilterExpr{x, y}
switch e.Op {
case token.LAND:
return ir.FilterExpr{Op: ir.FilterAndOp, Args: args}
case token.LOR:
return ir.FilterExpr{Op: ir.FilterOrOp, Args: args}
case token.NEQ:
return ir.FilterExpr{Op: ir.FilterNeqOp, Args: args}
case token.EQL:
return ir.FilterExpr{Op: ir.FilterEqOp, Args: args}
case token.GTR:
return ir.FilterExpr{Op: ir.FilterGtOp, Args: args}
case token.LSS:
return ir.FilterExpr{Op: ir.FilterLtOp, Args: args}
case token.GEQ:
return ir.FilterExpr{Op: ir.FilterGtEqOp, Args: args}
case token.LEQ:
return ir.FilterExpr{Op: ir.FilterLtEqOp, Args: args}
default:
panic(conv.errorf(e, "unexpected binary op: %s", e.Op.String()))
}
case *ast.SelectorExpr:
// Property access on a match variable; op.path is the dot-joined selector
// chain and op.varName the string used to index the matcher.
op := conv.inspectFilterSelector(e)
switch op.path {
case "Text":
return ir.FilterExpr{Op: ir.FilterVarTextOp, Value: op.varName}
case "Line":
return ir.FilterExpr{Op: ir.FilterVarLineOp, Value: op.varName}
case "Pure":
return ir.FilterExpr{Op: ir.FilterVarPureOp, Value: op.varName}
case "Const":
return ir.FilterExpr{Op: ir.FilterVarConstOp, Value: op.varName}
case "ConstSlice":
return ir.FilterExpr{Op: ir.FilterVarConstSliceOp, Value: op.varName}
case "Addressable":
return ir.FilterExpr{Op: ir.FilterVarAddressableOp, Value: op.varName}
case "Type.Size":
return ir.FilterExpr{Op: ir.FilterVarTypeSizeOp, Value: op.varName}
}
case *ast.CallExpr:
op := conv.inspectFilterSelector(e)
// First the calls whose arguments are not converted as filter expressions:
// they take no arguments, a string argument, or a function name.
switch op.path {
case "Deadcode":
return ir.FilterExpr{Op: ir.FilterDeadcodeOp}
case "GoVersion.Eq":
return ir.FilterExpr{Op: ir.FilterGoVersionEqOp, Value: conv.parseStringArg(e.Args[0])}
case "GoVersion.LessThan":
return ir.FilterExpr{Op: ir.FilterGoVersionLessThanOp, Value: conv.parseStringArg(e.Args[0])}
case "GoVersion.GreaterThan":
return ir.FilterExpr{Op: ir.FilterGoVersionGreaterThanOp, Value: conv.parseStringArg(e.Args[0])}
case "GoVersion.LessEqThan":
return ir.FilterExpr{Op: ir.FilterGoVersionLessEqThanOp, Value: conv.parseStringArg(e.Args[0])}
case "GoVersion.GreaterEqThan":
return ir.FilterExpr{Op: ir.FilterGoVersionGreaterEqThanOp, Value: conv.parseStringArg(e.Args[0])}
case "File.Imports":
return ir.FilterExpr{Op: ir.FilterFileImportsOp, Value: conv.parseStringArg(e.Args[0])}
case "File.PkgPath.Matches":
return ir.FilterExpr{Op: ir.FilterFilePkgPathMatchesOp, Value: conv.parseStringArg(e.Args[0])}
case "File.Name.Matches":
return ir.FilterExpr{Op: ir.FilterFileNameMatchesOp, Value: conv.parseStringArg(e.Args[0])}
case "Filter":
funcName, ok := e.Args[0].(*ast.Ident)
if !ok {
panic(conv.errorf(e.Args[0], "only named function args are supported"))
}
args := []ir.FilterExpr{
{Op: ir.FilterFilterFuncRefOp, Value: funcName.String()},
}
return ir.FilterExpr{Op: ir.FilterVarFilterOp, Value: op.varName, Args: args}
}
// The remaining calls get their arguments converted as filter expressions.
args := convertExprList(e.Args)
switch op.path {
case "Value.Int":
return ir.FilterExpr{Op: ir.FilterVarValueIntOp, Value: op.varName, Args: args}
case "Text.Matches":
return ir.FilterExpr{Op: ir.FilterVarTextMatchesOp, Value: op.varName, Args: args}
case "Node.Is":
return ir.FilterExpr{Op: ir.FilterVarNodeIsOp, Value: op.varName, Args: args}
case "Node.Parent.Is":
if op.varName != "$$" {
// TODO: remove this restriction.
panic(conv.errorf(e.Args[0], "only $$ parent nodes are implemented"))
}
return ir.FilterExpr{Op: ir.FilterRootNodeParentIsOp, Args: args}
case "Object.Is":
return ir.FilterExpr{Op: ir.FilterVarObjectIsOp, Value: op.varName, Args: args}
case "Type.Is":
return ir.FilterExpr{Op: ir.FilterVarTypeIsOp, Value: op.varName, Args: args}
case "Type.Underlying.Is":
return ir.FilterExpr{Op: ir.FilterVarTypeUnderlyingIsOp, Value: op.varName, Args: args}
case "Type.ConvertibleTo":
return ir.FilterExpr{Op: ir.FilterVarTypeConvertibleToOp, Value: op.varName, Args: args}
case "Type.AssignableTo":
return ir.FilterExpr{Op: ir.FilterVarTypeAssignableToOp, Value: op.varName, Args: args}
case "Type.Implements":
return ir.FilterExpr{Op: ir.FilterVarTypeImplementsOp, Value: op.varName, Args: args}
}
}
// Nothing matched: the zero value signals an unsupported expression.
return ir.FilterExpr{}
}
// parseStringArg returns the string value of e as determined by
// toStringValue, and reports an error when e has no such value.
func (conv *converter) parseStringArg(e ast.Expr) string {
	if s, ok := conv.toStringValue(e); ok {
		return s
	}
	panic(conv.errorf(e, "expected a string literal argument"))
}
// toStringValue extracts a compile-time string from x.
//
// String literals are unquoted directly. Any other expression must be
// recorded in the type info with the type "string" and carry a constant
// string value. The second result is false when no string value can be
// determined; the first result is then "".
func (conv *converter) toStringValue(x ast.Node) (string, bool) {
	switch x := x.(type) {
	case *ast.BasicLit:
		if x.Kind != token.STRING {
			return "", false
		}
		s, err := strconv.Unquote(x.Value)
		if err != nil {
			return "", false
		}
		return s, true
	case ast.Expr:
		tv, ok := conv.types.Types[x]
		if !ok || tv.Type == nil || tv.Type.String() != "string" {
			return "", false
		}
		// A string-typed expression is not necessarily a constant (think of a
		// variable): its Value is nil then, and constant.StringVal would
		// panic on it instead of letting the caller report a proper error.
		if tv.Value == nil || tv.Value.Kind() != constant.String {
			return "", false
		}
		return constant.StringVal(tv.Value), true
	}
	return "", false
}
// inspectFilterSelector decomposes an expression shaped like
// m["var"].A.B or m["var"].A.B(args...) into its parts.
//
// When e is a call, its arguments are kept in args. The selector names are
// joined with dots into path ("A.B"); calls met while walking the chain are
// looked through. If what remains is an index expression over an identifier,
// the identifier becomes mapName and the string value of the index becomes
// varName (left empty when the index has no string value).
func (conv *converter) inspectFilterSelector(e ast.Expr) filterExprSelector {
	var sel filterExprSelector

	if outer, isCall := e.(*ast.CallExpr); isCall {
		sel.args = outer.Args
		e = outer.Fun
	}

	// Walk towards the root, prepending each selector name to the chain.
	chain := ""
walk:
	for {
		switch node := e.(type) {
		case *ast.CallExpr:
			e = node.Fun
		case *ast.SelectorExpr:
			if chain == "" {
				chain = node.Sel.Name
			} else {
				chain = node.Sel.Name + "." + chain
			}
			e = astutil.Unparen(node.X)
		default:
			break walk
		}
	}
	sel.path = chain

	indexing, isIndex := astutil.Unparen(e).(*ast.IndexExpr)
	if !isIndex {
		return sel
	}
	root, isIdent := astutil.Unparen(indexing.X).(*ast.Ident)
	if !isIdent {
		return sel
	}
	sel.mapName = root.Name
	sel.varName, _ = conv.toStringValue(indexing.Index)
	return sel
}
// filterExprSelector is the result of inspectFilterSelector: the parts of an
// expression shaped like mapName["varName"].path(args...).
type filterExprSelector struct {
// mapName is the identifier being indexed at the root of the expression.
mapName string
// varName is the string value of the index expression ("" if it has none).
varName string
// path is the selector chain after the indexing, joined with dots.
path string
// args holds the arguments of the outermost call, when the expression is a call.
args []ast.Expr
}