package quasigo

import (
	"fmt"
	"go/ast"
	"go/constant"
	"go/token"
	"go/types"

	"github.com/quasilyte/go-ruleguard/ruleguard/goutil"
	"golang.org/x/tools/go/ast/astutil"
)

// compile translates fn into a bytecode function.
//
// The compiler reports problems by panicking with a compileError value.
// Such a panic is recovered here and returned as err; any other panic
// value is re-raised untouched.
func compile(ctx *CompileContext, fn *ast.FuncDecl) (compiled *Func, err error) {
	defer func() {
		if err != nil {
			return
		}
		rv := recover()
		if rv == nil {
			return
		}
		if compileErr, ok := rv.(compileError); ok {
			err = compileErr
			return
		}
		panic(rv) // not our panic
	}()

	return compileFunc(ctx, fn), nil
}

// compileFunc creates a fresh compiler for fn and runs it.
// It panics with a compileError if fn can't be compiled.
func compileFunc(ctx *CompileContext, fn *ast.FuncDecl) *Func {
	cl := compiler{
		ctx:              ctx,
		fnType:           ctx.Types.ObjectOf(fn.Name).Type().(*types.Signature),
		constantsPool:    make(map[interface{}]int),
		intConstantsPool: make(map[int]int),
		locals:           make(map[string]int),
	}
	return cl.compileFunc(fn)
}

// compiler holds the state needed to compile a single function.
type compiler struct {
	ctx *CompileContext

	// fnType is the signature of the function being compiled;
	// retType is the type of its only result.
	fnType  *types.Signature
	retType types.Type

	// lastOp is the most recent opcode written by emit.
	lastOp opcode

	// locals maps a local variable name to its slot index.
	// All locals live in one flat namespace: there are no nested scopes.
	locals map[string]int
	// constantsPool and intConstantsPool map an already interned value
	// to its index inside constants and intConstants respectively.
	constantsPool    map[interface{}]int
	intConstantsPool map[int]int

	// params maps a parameter name to its positional index.
	params map[string]int

	// code is the bytecode produced so far; constants and intConstants
	// are the pools that the code refers to by index.
	code         []byte
	constants    []interface{}
	intConstants []int

	// breakTarget and continueTarget are the labels of the innermost
	// loop being compiled.
	breakTarget    *label
	continueTarget *label

	// labels lists every label created by newLabel; linkJumps walks it.
	labels []*label
}

// label is a jump destination that may be bound after the jumps
// that refer to it are emitted.
type label struct {
	// targetPos is the code offset recorded by bindLabel.
	targetPos int
	// sources are the code offsets of the jump instructions
	// that need to be patched to reach targetPos.
	sources []int
}

// compileError is the panic value used to abort the compilation
// with a user-facing message.
type compileError string

// Error implements the error interface.
func (e compileError) Error() string { return string(e) }

// compileFunc validates the signature of fn, compiles its body and
// returns the resulting function. Debug info (parameter and local names)
// is registered in the compile context environment.
func (cl *compiler) compileFunc(fn *ast.FuncDecl) *Func {
	if cl.fnType.Results().Len() != 1 {
		panic(cl.errorf(fn.Name, "only functions with a single non-void results are supported"))
	}
	cl.retType = cl.fnType.Results().At(0).Type()
	if !cl.isSupportedType(cl.retType) {
		panic(cl.errorUnsupportedType(fn.Name, cl.retType, "function result"))
	}

	// A declaration without a body has nothing to compile; report it
	// instead of dereferencing a nil block statement below.
	if fn.Body == nil {
		panic(cl.errorf(fn.Name, "can't compile a function without a body"))
	}

	numParams := cl.fnType.Params().Len()
	dbg := funcDebugInfo{
		paramNames: make([]string, numParams),
	}

	cl.params = make(map[string]int, numParams)
	for i := 0; i < numParams; i++ {
		p := cl.fnType.Params().At(i)
		paramName := p.Name()
		paramType := p.Type()
		cl.params[paramName] = i
		dbg.paramNames[i] = paramName
		if !cl.isSupportedType(paramType) {
			panic(cl.errorUnsupportedType(fn.Name, paramType, paramName+" param"))
		}
	}

	cl.compileStmt(fn.Body)

	// Jump offsets are patched in place, so the code slice handed to
	// the Func below already contains the final offsets.
	cl.linkJumps()

	compiled := &Func{
		code:         cl.code,
		constants:    cl.constants,
		intConstants: cl.intConstants,
	}
	if len(cl.locals) != 0 {
		dbg.localNames = make([]string, len(cl.locals))
		for localName, localIndex := range cl.locals {
			dbg.localNames[localIndex] = localName
		}
	}
	cl.ctx.Env.debug.funcs[compiled] = dbg

	return compiled
}

// compileStmt dispatches on the statement kind.
// Unsupported statements abort the compilation.
func (cl *compiler) compileStmt(stmt ast.Stmt) {
	switch stmt := stmt.(type) {
	case *ast.ReturnStmt:
		cl.compileReturnStmt(stmt)

	case *ast.AssignStmt:
		cl.compileAssignStmt(stmt)

	case *ast.IncDecStmt:
		cl.compileIncDecStmt(stmt)

	case *ast.IfStmt:
		cl.compileIfStmt(stmt)

	case *ast.ForStmt:
		cl.compileForStmt(stmt)

	case *ast.BranchStmt:
		cl.compileBranchStmt(stmt)

	case *ast.BlockStmt:
		for i := range stmt.List {
			cl.compileStmt(stmt.List[i])
		}

	default:
		panic(cl.errorf(stmt, "can't compile %T yet", stmt))
	}
}

// compileIncDecStmt compiles `x++` and `x--` where x is a local variable.
func (cl *compiler) compileIncDecStmt(stmt *ast.IncDecStmt) {
	varname, ok := stmt.X.(*ast.Ident)
	if !ok {
		panic(cl.errorf(stmt.X, "can assign only to simple variables"))
	}
	id := cl.getLocal(varname, varname.String())
	if stmt.Tok == token.INC {
		cl.emit8(opIncLocal, id)
	} else {
		cl.emit8(opDecLocal, id)
	}
}

// compileBranchStmt compiles an unlabeled `break`.
// Other branch statements are not implemented yet.
func (cl *compiler) compileBranchStmt(branch *ast.BranchStmt) {
	if branch.Label != nil {
		panic(cl.errorf(branch.Label, "can't compile %s with a label", branch.Tok))
	}

	switch branch.Tok {
	case token.BREAK:
		cl.emitJump(opJump, cl.breakTarget)
	default:
		panic(cl.errorf(branch, "can't compile %s yet", branch.Tok))
	}
}

// compileForStmt compiles `for <cond> { ... }` and `for { ... }` loops.
//
// Any loop that has an init or a post statement is rejected: compiling it
// while ignoring those clauses would change the meaning of the program.
func (cl *compiler) compileForStmt(stmt *ast.ForStmt) {
	labelBreak := cl.newLabel()
	labelContinue := cl.newLabel()
	prevBreakTarget := cl.breakTarget
	prevContinueTarget := cl.continueTarget
	cl.breakTarget = labelBreak
	cl.continueTarget = labelContinue

	switch {
	case stmt.Init != nil || stmt.Post != nil:
		// Will be implemented later; probably when the max number of locals will be lifted.
		panic(cl.errorf(stmt, "can't compile C-style for loops yet"))

	case stmt.Cond != nil:
		// `for <cond> { ... }`
		// The condition is placed after the body; the initial jump
		// makes sure it is checked before the first iteration.
		labelBody := cl.newLabel()
		cl.emitJump(opJump, labelContinue)
		cl.bindLabel(labelBody)
		cl.compileStmt(stmt.Body)
		cl.bindLabel(labelContinue)
		cl.compileExpr(stmt.Cond)
		cl.emitJump(opJumpTrue, labelBody)
		cl.bindLabel(labelBreak)

	default:
		// `for { ... }`
		cl.bindLabel(labelContinue)
		cl.compileStmt(stmt.Body)
		cl.emitJump(opJump, labelContinue)
		cl.bindLabel(labelBreak)
	}

	cl.breakTarget = prevBreakTarget
	cl.continueTarget = prevContinueTarget
}

// compileIfStmt compiles `if <cond> { ... }` with an optional else branch.
// An init clause (`if x := f(); ...`) is not supported and is reported
// as an error instead of being silently dropped.
func (cl *compiler) compileIfStmt(stmt *ast.IfStmt) {
	if stmt.Init != nil {
		panic(cl.errorf(stmt.Init, "can't compile if statements with an init clause yet"))
	}

	if stmt.Else == nil {
		labelEnd := cl.newLabel()
		cl.compileExpr(stmt.Cond)
		cl.emitJump(opJumpFalse, labelEnd)
		cl.compileStmt(stmt.Body)
		cl.bindLabel(labelEnd)
		return
	}

	labelEnd := cl.newLabel()
	labelElse := cl.newLabel()
	cl.compileExpr(stmt.Cond)
	cl.emitJump(opJumpFalse, labelElse)
	cl.compileStmt(stmt.Body)
	// The jump over the else branch is skipped when the last emitted
	// opcode already leaves the body unconditionally.
	if !cl.isUncondJump(cl.lastOp) {
		cl.emitJump(opJump, labelEnd)
	}
	cl.bindLabel(labelElse)
	cl.compileStmt(stmt.Else)
	cl.bindLabel(labelEnd)
}

// compileAssignStmt compiles `x := expr` and `x = expr` for a single
// local variable x.
//
// Compound assignments (`+=` and friends) are rejected explicitly;
// otherwise they would be compiled as a plain `x = expr`.
func (cl *compiler) compileAssignStmt(assign *ast.AssignStmt) {
	if len(assign.Lhs) != 1 {
		panic(cl.errorf(assign, "only single left operand is allowed in assignments"))
	}
	if len(assign.Rhs) != 1 {
		panic(cl.errorf(assign, "only single right operand is allowed in assignments"))
	}
	lhs := assign.Lhs[0]
	rhs := assign.Rhs[0]
	varname, ok := lhs.(*ast.Ident)
	if !ok {
		panic(cl.errorf(lhs, "can assign only to simple variables"))
	}
	if assign.Tok != token.DEFINE && assign.Tok != token.ASSIGN {
		panic(cl.errorf(assign, "can't compile %s assignment yet", assign.Tok))
	}

	cl.compileExpr(rhs)

	name := varname.String()
	typ := cl.ctx.Types.TypeOf(varname)

	var id int
	if assign.Tok == token.DEFINE {
		if _, ok := cl.locals[name]; ok {
			panic(cl.errorf(lhs, "%s variable shadowing is not allowed", varname))
		}
		// compileIdent resolves params before locals, so a local that
		// reuses a param name would never be read back.
		if _, ok := cl.params[name]; ok {
			panic(cl.errorf(lhs, "%s variable shadowing is not allowed", varname))
		}
		if !cl.isSupportedType(typ) {
			panic(cl.errorUnsupportedType(varname, typ, name+" local variable"))
		}
		if len(cl.locals) == maxFuncLocals {
			panic(cl.errorf(lhs, "can't define %s: too many locals", varname))
		}
		id = len(cl.locals)
		cl.locals[name] = id
	} else {
		id = cl.getLocal(varname, name)
	}
	cl.emit8(pickOp(typeIsInt(typ), opSetIntLocal, opSetLocal), id)
}

// getLocal returns the slot index of the local variable varname.
// It aborts the compilation if varname is a param or is not defined;
// v is used only for the error position.
func (cl *compiler) getLocal(v ast.Expr, varname string) int {
	id, ok := cl.locals[varname]
	if !ok {
		if _, ok := cl.params[varname]; ok {
			panic(cl.errorf(v, "can't assign to %s, params are readonly", varname))
		}
		panic(cl.errorf(v, "%s is not a writeable local variable", varname))
	}
	return id
}

// compileReturnStmt compiles a return of the first result expression.
// Returning the `true` and `false` identifiers uses dedicated opcodes.
func (cl *compiler) compileReturnStmt(ret *ast.ReturnStmt) {
	if len(ret.Results) == 0 {
		panic(cl.errorf(ret, "'naked' return statements are not allowed"))
	}

	result := ret.Results[0]
	switch {
	case identName(result) == "true":
		cl.emit(opReturnTrue)
	case identName(result) == "false":
		cl.emit(opReturnFalse)
	default:
		cl.compileExpr(result)
		typ := cl.ctx.Types.TypeOf(result)
		cl.emit(pickOp(typeIsInt(typ), opReturnIntTop, opReturnTop))
	}
}

// compileExpr compiles an expression.
// Expressions with a known constant value are compiled as that constant.
func (cl *compiler) compileExpr(e ast.Expr) {
	cv := cl.ctx.Types.Types[e].Value
	if cv != nil {
		cl.compileConstantValue(e, cv)
		return
	}

	switch e := e.(type) {
	case *ast.ParenExpr:
		cl.compileExpr(e.X)

	case *ast.Ident:
		cl.compileIdent(e)

	case *ast.SelectorExpr:
		cl.compileSelectorExpr(e)

	case *ast.UnaryExpr:
		switch e.Op {
		case token.NOT:
			cl.compileUnaryOp(opNot, e)
		default:
			panic(cl.errorf(e, "can't compile unary %s yet", e.Op))
		}

	case *ast.SliceExpr:
		cl.compileSliceExpr(e)

	case *ast.BinaryExpr:
		cl.compileBinaryExpr(e)

	case *ast.CallExpr:
		cl.compileCallExpr(e)

	default:
		panic(cl.errorf(e, "can't compile %T yet", e))
	}
}

// compileSelectorExpr compiles `x.sel` as a native function call
// registered under the (type of x, sel) key, with x as its argument.
// Selectors without such a registration are rejected.
func (cl *compiler) compileSelectorExpr(e *ast.SelectorExpr) {
	typ := cl.ctx.Types.TypeOf(e.X)
	key := funcKey{
		name:      e.Sel.String(),
		qualifier: typ.String(),
	}

	if funcID, ok := cl.ctx.Env.nameToNativeFuncID[key]; ok {
		cl.compileExpr(e.X)
		cl.emit16(opCallNative, int(funcID))
		return
	}

	panic(cl.errorf(e, "can't compile %s field access", e.Sel))
}

// compileBinaryExpr compiles a binary expression.
// The opcode is selected by the operator and the type of the left operand.
func (cl *compiler) compileBinaryExpr(e *ast.BinaryExpr) {
	typ := cl.ctx.Types.TypeOf(e.X)

	switch e.Op {
	case token.LOR:
		cl.compileOr(e)
	case token.LAND:
		cl.compileAnd(e)

	case token.NEQ:
		switch {
		case identName(e.X) == "nil":
			cl.compileExpr(e.Y)
			cl.emit(opIsNotNil)
		case identName(e.Y) == "nil":
			cl.compileExpr(e.X)
			cl.emit(opIsNotNil)
		case typeIsString(typ):
			cl.compileBinaryOp(opNotEqString, e)
		case typeIsInt(typ):
			cl.compileBinaryOp(opNotEqInt, e)
		default:
			panic(cl.errorf(e, "!= is not implemented for %s operands", typ))
		}
	case token.EQL:
		switch {
		case identName(e.X) == "nil":
			cl.compileExpr(e.Y)
			cl.emit(opIsNil)
		case identName(e.Y) == "nil":
			cl.compileExpr(e.X)
			cl.emit(opIsNil)
		case typeIsString(typ):
			cl.compileBinaryOp(opEqString, e)
		case typeIsInt(typ):
			cl.compileBinaryOp(opEqInt, e)
		default:
			panic(cl.errorf(e, "== is not implemented for %s operands", typ))
		}

	case token.GTR:
		cl.compileIntBinaryOp(e, opGtInt, typ)
	case token.GEQ:
		cl.compileIntBinaryOp(e, opGtEqInt, typ)
	case token.LSS:
		cl.compileIntBinaryOp(e, opLtInt, typ)
	case token.LEQ:
		cl.compileIntBinaryOp(e, opLtEqInt, typ)

	case token.ADD:
		switch {
		case typeIsString(typ):
			cl.compileBinaryOp(opConcat, e)
		case typeIsInt(typ):
			cl.compileBinaryOp(opAdd, e)
		default:
			panic(cl.errorf(e, "+ is not implemented for %s operands", typ))
		}

	case token.SUB:
		cl.compileIntBinaryOp(e, opSub, typ)

	default:
		panic(cl.errorf(e, "can't compile binary %s yet", e.Op))
	}
}

// compileIntBinaryOp compiles e with op if typ is an int type
// and aborts the compilation otherwise.
func (cl *compiler) compileIntBinaryOp(e *ast.BinaryExpr, op opcode, typ types.Type) {
	switch {
	case typeIsInt(typ):
		cl.compileBinaryOp(op, e)
	default:
		panic(cl.errorf(e, "%s is not implemented for %s operands", e.Op, typ))
	}
}

// compileSliceExpr compiles 2-index slicing of a string.
func (cl *compiler) compileSliceExpr(slice *ast.SliceExpr) {
	if slice.Slice3 {
		panic(cl.errorf(slice, "can't compile 3-index slicing"))
	}

	// No need to do slicing, `s[:]` is a no-op.
	if slice.Low == nil && slice.High == nil {
		cl.compileExpr(slice.X)
		return
	}

	if !typeIsString(cl.ctx.Types.TypeOf(slice.X)) {
		panic(cl.errorf(slice.X, "can't compile slicing of something that is not a string"))
	}

	cl.compileExpr(slice.X)
	switch {
	case slice.Low == nil:
		// `s[:high]`
		cl.compileExpr(slice.High)
		cl.emit(opStringSliceTo)
	case slice.High == nil:
		// `s[low:]`
		cl.compileExpr(slice.Low)
		cl.emit(opStringSliceFrom)
	default:
		// `s[low:high]`
		cl.compileExpr(slice.Low)
		cl.compileExpr(slice.High)
		cl.emit(opStringSlice)
	}
}

// compileBuiltinCall compiles a call to a Go builtin function.
// Only len() of a string is implemented.
func (cl *compiler) compileBuiltinCall(fn *ast.Ident, call *ast.CallExpr) {
	switch fn.Name {
	case `len`:
		s := call.Args[0]
		cl.compileExpr(s)
		if !typeIsString(cl.ctx.Types.TypeOf(s)) {
			panic(cl.errorf(s, "can't compile len() with non-string argument yet"))
		}
		cl.emit(opStringLen)
	default:
		panic(cl.errorf(fn, "can't compile %s() builtin function call yet", fn))
	}
}

// compileCallExpr compiles a call to a builtin or to a native function
// registered in the environment. For native calls the receiver expression
// (if any) is pushed first, followed by the arguments in order.
func (cl *compiler) compileCallExpr(call *ast.CallExpr) {
	if id, ok := astutil.Unparen(call.Fun).(*ast.Ident); ok {
		_, isBuiltin := cl.ctx.Types.ObjectOf(id).(*types.Builtin)
		if isBuiltin {
			cl.compileBuiltinCall(id, call)
			return
		}
	}

	expr, fn := goutil.ResolveFunc(cl.ctx.Types, call.Fun)
	if fn == nil {
		panic(cl.errorf(call.Fun, "can't resolve the called function"))
	}

	// TODO: just use Func.FullName as a key?
	key := funcKey{name: fn.Name()}
	sig := fn.Type().(*types.Signature)
	if sig.Recv() != nil {
		key.qualifier = sig.Recv().Type().String()
	} else {
		key.qualifier = fn.Pkg().Path()
	}

	if funcID, ok := cl.ctx.Env.nameToNativeFuncID[key]; ok {
		if expr != nil {
			cl.compileExpr(expr)
		}
		for _, arg := range call.Args {
			cl.compileExpr(arg)
		}
		cl.emit16(opCallNative, int(funcID))
		return
	}

	panic(cl.errorf(call.Fun, "can't compile a call to %s func", key))
}

// compileUnaryOp compiles the operand of e and then emits op.
func (cl *compiler) compileUnaryOp(op opcode, e *ast.UnaryExpr) {
	cl.compileExpr(e.X)
	cl.emit(op)
}

// compileBinaryOp compiles both operands of e (left first) and then emits op.
func (cl *compiler) compileBinaryOp(op opcode, e *ast.BinaryExpr) {
	cl.compileExpr(e.X)
	cl.compileExpr(e.Y)
	cl.emit(op)
}

// compileOr compiles `x || y`: the code for y is jumped over
// when opJumpTrue takes the branch after x.
func (cl *compiler) compileOr(e *ast.BinaryExpr) {
	labelEnd := cl.newLabel()
	cl.compileExpr(e.X)
	cl.emit(opDup)
	cl.emitJump(opJumpTrue, labelEnd)
	cl.compileExpr(e.Y)
	cl.bindLabel(labelEnd)
}

// compileAnd compiles `x && y`: the code for y is jumped over
// when opJumpFalse takes the branch after x.
func (cl *compiler) compileAnd(e *ast.BinaryExpr) {
	labelEnd := cl.newLabel()
	cl.compileExpr(e.X)
	cl.emit(opDup)
	cl.emitJump(opJumpFalse, labelEnd)
	cl.compileExpr(e.Y)
	cl.bindLabel(labelEnd)
}

// compileIdent compiles a read of an identifier.
// Constants are resolved first, then params, then locals.
func (cl *compiler) compileIdent(ident *ast.Ident) {
	tv := cl.ctx.Types.Types[ident]
	cv := tv.Value
	if cv != nil {
		cl.compileConstantValue(ident, cv)
		return
	}
	if paramIndex, ok := cl.params[ident.String()]; ok {
		cl.emit8(pickOp(typeIsInt(tv.Type), opPushIntParam, opPushParam), paramIndex)
		return
	}
	if localIndex, ok := cl.locals[ident.String()]; ok {
		cl.emit8(pickOp(typeIsInt(tv.Type), opPushIntLocal, opPushLocal), localIndex)
		return
	}

	panic(cl.errorf(ident, "can't compile a %s (type %s) variable read", ident.String(), tv.Type))
}

// compileConstantValue emits a push of the constant cv.
// Bool, string and int constants are supported; source is used
// only for the error position.
//
// Pool indexes are encoded as a single byte, so an index above 0xff
// is reported as an error instead of being truncated.
func (cl *compiler) compileConstantValue(source ast.Expr, cv constant.Value) {
	switch cv.Kind() {
	case constant.Bool:
		v := constant.BoolVal(cv)
		if v {
			cl.emit(opPushTrue)
		} else {
			cl.emit(opPushFalse)
		}

	case constant.String:
		v := constant.StringVal(cv)
		id := cl.internConstant(v)
		if id > 0xff {
			panic(cl.errorf(source, "too many constants"))
		}
		cl.emit8(opPushConst, id)

	case constant.Int:
		v, exact := constant.Int64Val(cv)
		if !exact {
			panic(cl.errorf(source, "non-exact int value"))
		}
		id := cl.internIntConstant(int(v))
		if id > 0xff {
			panic(cl.errorf(source, "too many int constants"))
		}
		cl.emit8(opPushIntConst, id)

	case constant.Complex:
		panic(cl.errorf(source, "can't compile complex number constants yet"))

	case constant.Float:
		panic(cl.errorf(source, "can't compile float constants yet"))

	default:
		panic(cl.errorf(source, "unexpected constant %v", cv))
	}
}

// internIntConstant returns the index of v inside the int constants pool,
// adding v to the pool on its first use.
func (cl *compiler) internIntConstant(v int) int {
	if id, ok := cl.intConstantsPool[v]; ok {
		return id
	}
	id := len(cl.intConstants)
	cl.intConstants = append(cl.intConstants, v)
	cl.intConstantsPool[v] = id
	return id
}

// internConstant returns the index of v inside the constants pool,
// adding v to the pool on its first use.
// Int values must go through internIntConstant instead.
func (cl *compiler) internConstant(v interface{}) int {
	if _, ok := v.(int); ok {
		panic("compiler error: int constant interned as interface{}")
	}
	if id, ok := cl.constantsPool[v]; ok {
		return id
	}
	id := len(cl.constants)
	cl.constants = append(cl.constants, v)
	cl.constantsPool[v] = id
	return id
}

// linkJumps patches the 2-byte operand of every recorded jump with
// the offset from the jump instruction to its label position.
func (cl *compiler) linkJumps() {
	for _, l := range cl.labels {
		for _, jumpPos := range l.sources {
			offset := l.targetPos - jumpPos
			patchPos := jumpPos + 1 // skip the opcode byte
			put16(cl.code, patchPos, offset)
		}
	}
}

// newLabel creates an unbound label and registers it for linkJumps.
func (cl *compiler) newLabel() *label {
	l := &label{}
	cl.labels = append(cl.labels, l)
	return l
}

// bindLabel makes l point to the current end of the code.
func (cl *compiler) bindLabel(l *label) {
	l.targetPos = len(cl.code)
}

// emit appends op to the code and records it as the last emitted opcode.
func (cl *compiler) emit(op opcode) {
	cl.lastOp = op
	cl.code = append(cl.code, byte(op))
}

// emitJump appends a jump instruction with a zeroed 2-byte operand
// and records its position in l, so linkJumps can patch it later.
func (cl *compiler) emitJump(op opcode, l *label) {
	l.sources = append(l.sources, len(cl.code))
	cl.emit(op)
	cl.code = append(cl.code, 0, 0)
}

// emit8 appends op followed by a 1-byte operand.
// Only the low byte of arg8 is stored, so callers must keep it
// within the 0..0xff range.
func (cl *compiler) emit8(op opcode, arg8 int) {
	cl.emit(op)
	cl.code = append(cl.code, byte(arg8))
}

// emit16 appends op followed by a 2-byte operand encoded by put16.
func (cl *compiler) emit16(op opcode, arg16 int) {
	cl.emit(op)
	buf := make([]byte, 2)
	put16(buf, 0, arg16)
	cl.code = append(cl.code, buf...)
}

// errorUnsupportedType builds an error about typ being used in
// the place described by where.
func (cl *compiler) errorUnsupportedType(e ast.Node, typ types.Type, where string) compileError {
	return cl.errorf(e, "%s type: %s is not supported, try something simpler", where, typ)
}

// errorf builds a compileError prefixed with the file name and line of n.
func (cl *compiler) errorf(n ast.Node, format string, args ...interface{}) compileError {
	loc := cl.ctx.Fset.Position(n.Pos())
	message := fmt.Sprintf("%s:%d: %s", loc.Filename, loc.Line, fmt.Sprintf(format, args...))
	return compileError(message)
}

// isUncondJump reports whether op is an unconditional jump or a return.
func (cl *compiler) isUncondJump(op opcode) bool {
	switch op {
	case opJump, opReturnFalse, opReturnTrue, opReturnTop, opReturnIntTop:
		return true
	default:
		return false
	}
}

// isSupportedType reports whether typ can be used as a type of
// a param, a local variable or a function result.
func (cl *compiler) isSupportedType(typ types.Type) bool {
	switch typ := typ.Underlying().(type) {
	case *types.Pointer:
		// 1. Pointers to structs are supported.
		_, isStruct := typ.Elem().Underlying().(*types.Struct)
		return isStruct

	case *types.Basic:
		// 2. Some of the basic types are supported.
		// TODO: support byte/uint8 and maybe float64.
		switch typ.Kind() {
		case types.Bool, types.Int, types.String:
			return true
		default:
			return false
		}

	case *types.Interface:
		// 3. Interfaces are supported.
		return true

	default:
		return false
	}
}