diff --git a/src/lox/ASTPrinter.java b/src/lox/ASTPrinter.java index 4d853c4..18587e8 100644 --- a/src/lox/ASTPrinter.java +++ b/src/lox/ASTPrinter.java @@ -28,6 +28,19 @@ class ASTPrinter implements Expr.Visitor { return parenthesize(expr.operator.lexeme, expr.right); } + @Override + public String visitCallExpr(Expr.Call expr) { + StringBuilder builder = new StringBuilder(); + + builder.append(expr.callee.toString()).append("("); + for (Expr arg : expr.arguments) { + builder.append(arg.accept(this)).append(", "); + } + builder.append(")"); + + return builder.toString(); + } + @Override public String visitVariableExpr(Expr.Variable expr) { return expr.name.toString(); diff --git a/src/lox/Expr.java b/src/lox/Expr.java index 65a7b97..5c0d816 100644 --- a/src/lox/Expr.java +++ b/src/lox/Expr.java @@ -6,6 +6,7 @@ abstract class Expr { interface Visitor { R visitAssignExpr(Assign expr); R visitBinaryExpr(Binary expr); + R visitCallExpr(Call expr); R visitGroupingExpr(Grouping expr); R visitLiteralExpr(Literal expr); R visitLogicalExpr(Logical expr); @@ -45,6 +46,23 @@ abstract class Expr { final Expr right; } + static class Call extends Expr { + Call(Expr callee, Token paren, List arguments) { + this.callee = callee; + this.paren = paren; + this.arguments = arguments; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitCallExpr(this); + } + + final Expr callee; + final Token paren; + final List arguments; + } + static class Grouping extends Expr { Grouping(Expr expression) { this.expression = expression; diff --git a/src/lox/Interpreter.java b/src/lox/Interpreter.java index aa6609e..6c203f8 100644 --- a/src/lox/Interpreter.java +++ b/src/lox/Interpreter.java @@ -1,9 +1,32 @@ package lox; +import java.util.ArrayList; import java.util.List; class Interpreter implements Expr.Visitor, Stmt.Visitor { - private Environment environment = new Environment(); + final Environment globals = new Environment(); + private Environment environment = globals; + + Interpreter() { + globals.define( + "clock", + new LoxCallable() { + @Override + public int arity() { + return 0; + } + + @Override + public Object call(Interpreter interpreter, List arguments) { + return (double) System.currentTimeMillis() / 1000.0; + } + + @Override + public String toString() { + return ""; + } + }); + } void interpret(List statements) { try { @@ -57,6 +80,27 @@ class Interpreter implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Object visitCallExpr(Expr.Call expr) { + Object callee = evaluate(expr.callee); + + List arguments = new ArrayList<>(); + for (Expr argument : expr.arguments) { + arguments.add(evaluate(argument)); + } + + if (!(callee instanceof LoxCallable)) { + throw new RuntimeError(expr.paren, "Can only call functions and classes"); + } + LoxCallable function = (LoxCallable) callee; + if (arguments.size() != function.arity()) { + throw new RuntimeError( + expr.paren, + "Expected " + function.arity() + " arguments but called with " + arguments.size() + "."); + } + return function.call(this, arguments); + } + @Override public Object visitBinaryExpr(Expr.Binary expr) { Object left = evaluate(expr.left); @@ -119,6 +163,13 @@ class Interpreter implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitFunctionStmt(Stmt.Function stmt) { + LoxFunction function = new LoxFunction(stmt, environment); + environment.define(stmt.name.lexeme, function); + return null; + } + @Override public Void visitIfStmt(Stmt.If stmt) { if (isTruthy(evaluate(stmt.condition))) { @@ -136,6 +187,15 @@ class Interpreter implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitReturnStmt(Stmt.Return stmt) { + Object value = null; + if (stmt.value != null) { + value = evaluate(stmt.value); + } + throw new Return(value); + } + @Override public Void visitVarStmt(Stmt.Var stmt) { Object value = null; diff --git a/src/lox/LoxCallable.java b/src/lox/LoxCallable.java new file mode 100644 index 0000000..f413593 --- /dev/null +++ b/src/lox/LoxCallable.java @@ -0,0 +1,9 @@ +package lox; + +import java.util.List; + +interface LoxCallable { + int arity(); + + Object call(Interpreter interpreter, List arguments); +} diff --git a/src/lox/LoxFunction.java b/src/lox/LoxFunction.java new file mode 100644 index 0000000..4b188ad --- /dev/null +++ b/src/lox/LoxFunction.java @@ -0,0 +1,37 @@ +package lox; + +import java.util.List; + +class LoxFunction implements LoxCallable { + private final Stmt.Function declaration; + private final Environment closure; + + LoxFunction(Stmt.Function declaration, Environment closure) { + this.closure = closure; + this.declaration = declaration; + } + + @Override + public Object call(Interpreter interpreter, List arguments) { + Environment environment = new Environment(closure); + for (int i = 0; i < declaration.params.size(); i++) { + environment.define(declaration.params.get(i).lexeme, arguments.get(i)); + } + try { + interpreter.executeBlock(declaration.body, environment); + } catch (Return returnValue) { + return returnValue.value; + } + return null; + } + + @Override + public int arity() { + return declaration.params.size(); + } + + @Override + public String toString() { + return ""; + } +} diff --git a/src/lox/Parser.java b/src/lox/Parser.java index 154a5b7..1d7e773 100644 --- a/src/lox/Parser.java +++ b/src/lox/Parser.java @@ -26,6 +26,9 @@ class Parser { private Stmt declaration() { try { + if (match(FUN)) { + return function("function"); + } if (match(VAR)) { return varDeclaration(); } @@ -57,6 +60,9 @@ class Parser { if (match(PRINT)) { return printStatement(); } + if (match(RETURN)) { + return returnStatement(); + } if (match(WHILE)) { return whileStatement(); } @@ -72,6 +78,16 @@ class Parser { return new Stmt.Print(value); } + private Stmt returnStatement() { + Token keyword = previous(); + Expr value = null; + if (!check(SEMICOLON)) { + value = expression(); + } + consume(SEMICOLON, "Expect ';' after return value"); + return new Stmt.Return(keyword, value); + } + private Stmt whileStatement() { consume(LEFT_PAREN, "Expect '(' after 'while"); Expr condition = expression(); @@ -87,6 +103,24 @@ class Parser { return new Stmt.Expression(value); } + private Stmt function(String kind) { + Token name = consume(IDENTIFIER, "Expect " + kind + " name."); + consume(LEFT_PAREN, "Expect '(' after " + kind + "name"); + List parameters = new ArrayList<>(); + if (!check(RIGHT_PAREN)) { + do { + if (parameters.size() >= 255) { + error(peek(), "Can't have more than 255 parameters."); + } + parameters.add(consume(IDENTIFIER, "Expect parameter name.")); + } while (match(COMMA)); + } + consume(RIGHT_PAREN, "Expect ')' after parameters."); + consume(LEFT_BRACE, "Expect '{' before " + kind + " body."); + List body = block(); + return new Stmt.Function(name, parameters, body); + } + private Stmt ifStatement() { consume(LEFT_PAREN, "Expect '(' after 'if'."); Expr condition = expression(); @@ -245,7 +279,38 @@ class Parser { Expr right = unary(); return new Expr.Unary(operator, right); } - return primary(); + return call(); + } + + private Expr call() { + Expr expr = primary(); + + while (true) { + if (match(LEFT_PAREN)) { + expr = finishCall(expr); + } else { + break; + } + } + + return expr; + } + + private Expr finishCall(Expr callee) { + List arguments = new ArrayList<>(); + + if (!check(RIGHT_PAREN)) { + do { + arguments.add(expression()); + } while (match(COMMA)); + } + Token paren = consume(RIGHT_PAREN, "Expected ')' after argument list"); + + if (arguments.size() >= 255) { + error(peek(), "Can't have more than 255 arguments"); + } + + return new Expr.Call(callee, paren, arguments); } private Expr primary() { diff --git a/src/lox/Return.java b/src/lox/Return.java new file mode 100644 index 0000000..d8338fb --- /dev/null +++ b/src/lox/Return.java @@ -0,0 +1,10 @@ +package lox; + +class Return extends RuntimeException{ + final Object value; + + Return(Object value) { + super(null, null, false, false); + this.value = value; + } +} diff --git a/src/lox/Stmt.java b/src/lox/Stmt.java index f097da1..3d0b5ac 100644 --- a/src/lox/Stmt.java +++ b/src/lox/Stmt.java @@ -6,9 +6,11 @@ abstract class Stmt { interface Visitor { R visitBlockStmt(Block stmt); R visitExpressionStmt(Expression stmt); + R visitFunctionStmt(Function stmt); R visitIfStmt(If stmt); R visitVarStmt(Var stmt); R visitPrintStmt(Print stmt); + R visitReturnStmt(Return stmt); R visitWhileStmt(While stmt); } @@ -38,6 +40,23 @@ abstract class Stmt { final Expr expression; } + static class Function extends Stmt { + Function(Token name, List params, List body) { + this.name = name; + this.params = params; + this.body = body; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitFunctionStmt(this); + } + + final Token name; + final List params; + final List body; + } + static class If extends Stmt { If(Expr condition, Stmt thenBranch, Stmt elseBranch) { this.condition = condition; @@ -83,6 +102,21 @@ abstract class Stmt { final Expr expression; } + static class Return extends Stmt { + Return(Token keyword, Expr value) { + this.keyword = keyword; + this.value = value; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitReturnStmt(this); + } + + final Token keyword; + final Expr value; + } + static class While extends Stmt { While(Expr condition, Stmt body) { this.condition = condition; diff --git a/src/tool/GenerateAST.java b/src/tool/GenerateAST.java index b91e423..dc127c2 100644 --- a/src/tool/GenerateAST.java +++ b/src/tool/GenerateAST.java @@ -20,6 +20,7 @@ public class GenerateAST { Arrays.asList( "Assign : Token name, Expr value", "Binary : Expr left, Token operator, Expr right", + "Call : Expr callee, Token paren, List arguments", "Grouping : Expr expression", "Literal : Object value", "Logical : Expr left, Token operator, Expr right", @@ -32,9 +33,11 @@ public class GenerateAST { Arrays.asList( "Block : List statements", "Expression : Expr expression", + "Function : Token name, List params, List body", "If : Expr condition, Stmt thenBranch, Stmt elseBranch", "Var : Token name, Expr initializer", "Print : Expr expression", + "Return : Token keyword, Expr value", "While : Expr condition, Stmt body")); }