diff --git a/src/lox/ASTPrinter.java b/src/lox/ASTPrinter.java index 713d31b..4d853c4 100644 --- a/src/lox/ASTPrinter.java +++ b/src/lox/ASTPrinter.java @@ -33,6 +33,11 @@ class ASTPrinter implements Expr.Visitor { return expr.name.toString(); } + @Override + public String visitLogicalExpr(Expr.Logical expr) { + return expr.left.accept(this) + expr.operator.toString() + expr.right.accept(this); + } + @Override public String visitAssignExpr(Expr.Assign expr) { return expr.name.toString(); diff --git a/src/lox/Expr.java b/src/lox/Expr.java index fe4dc2e..65a7b97 100644 --- a/src/lox/Expr.java +++ b/src/lox/Expr.java @@ -8,6 +8,7 @@ abstract class Expr { R visitBinaryExpr(Binary expr); R visitGroupingExpr(Grouping expr); R visitLiteralExpr(Literal expr); + R visitLogicalExpr(Logical expr); R visitUnaryExpr(Unary expr); R visitVariableExpr(Variable expr); } @@ -70,6 +71,23 @@ abstract class Expr { final Object value; } + static class Logical extends Expr { + Logical(Expr left, Token operator, Expr right) { + this.left = left; + this.operator = operator; + this.right = right; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitLogicalExpr(this); + } + + final Expr left; + final Token operator; + final Expr right; + } + static class Unary extends Expr { Unary(Token operator, Expr right) { this.operator = operator; diff --git a/src/lox/Interpreter.java b/src/lox/Interpreter.java index 8cd9548..aa6609e 100644 --- a/src/lox/Interpreter.java +++ b/src/lox/Interpreter.java @@ -20,6 +20,23 @@ class Interpreter implements Expr.Visitor, Stmt.Visitor { return expr.value; } + @Override + public Object visitLogicalExpr(Expr.Logical expr) { + Object left = evaluate(expr.left); + + if (expr.operator.type == TokenType.OR) { + if (isTruthy(left)) { + return left; + } else { + if (!isTruthy(left)) { + return left; + } + } + } + + return evaluate(expr.right); + } + @Override public Object visitGroupingExpr(Expr.Grouping expr) { return evaluate(expr.expression); @@ -102,6 +119,16 @@ class Interpreter implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitIfStmt(Stmt.If stmt) { + if (isTruthy(evaluate(stmt.condition))) { + execute(stmt.thenBranch); + } else { + execute(stmt.elseBranch); + } + return null; + } + @Override public Void visitPrintStmt(Stmt.Print stmt) { Object value = evaluate(stmt.expression); @@ -119,6 +146,14 @@ class Interpreter implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitWhileStmt(Stmt.While stmt) { + while (isTruthy(stmt.condition)) { + execute(stmt.body); + } + return null; + } + @Override public Void visitBlockStmt(Stmt.Block stmt) { executeBlock(stmt.statements, new Environment(environment)); diff --git a/src/lox/Parser.java b/src/lox/Parser.java index e8cbefc..154a5b7 100644 --- a/src/lox/Parser.java +++ b/src/lox/Parser.java @@ -1,6 +1,7 @@ package lox; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import static lox.TokenType.*; @@ -47,9 +48,18 @@ class Parser { } private Stmt statement() { + if (match(FOR)) { + return forStatement(); + } + if (match(IF)) { + return ifStatement(); + } if (match(PRINT)) { return printStatement(); } + if (match(WHILE)) { + return whileStatement(); + } if (match(LEFT_BRACE)) { return new Stmt.Block(block()); } @@ -62,12 +72,78 @@ class Parser { return new Stmt.Print(value); } + private Stmt whileStatement() { + consume(LEFT_PAREN, "Expect '(' after 'while"); + Expr condition = expression(); + consume(RIGHT_PAREN, "Expect ')' after while condition"); + Stmt body = statement(); + + return new Stmt.While(condition, body); + } + private Stmt expressionStatement() { Expr value = expression(); consume(SEMICOLON, "Expect ';' after value."); return new Stmt.Expression(value); } + private Stmt ifStatement() { + consume(LEFT_PAREN, "Expect '(' after 'if'."); + Expr condition = expression(); + consume(RIGHT_PAREN, "Expect ')' after if condition"); + + Stmt thenBranch = statement(); + Stmt elseBranch = null; + if (match(ELSE)) { + elseBranch = statement(); + } + return new Stmt.If(condition, thenBranch, elseBranch); + } + + private Stmt forStatement() { + consume(LEFT_PAREN, "Expect '(' after 'for'."); + + Stmt initializer; + if (match(SEMICOLON)) { + initializer = null; + } else if (match(VAR)) { + initializer = varDeclaration(); + } else { + initializer = expressionStatement(); + } + + Expr condition = null; + if (!check(SEMICOLON)) { + condition = expression(); + } + consume(SEMICOLON, "Expect ';' after loop condition."); + + Expr increment = null; + if (!check(RIGHT_PAREN)) { + increment = expression(); + } + consume(RIGHT_PAREN, "Expect ')' after for clauses"); + Stmt body = statement(); + + if (increment != null) { + // move incrementer into the body + body = new Stmt.Block(Arrays.asList(body, new Stmt.Expression(increment))); + } + + if (condition == null) { + // set an infinite loop up if no condition given + condition = new Expr.Literal(true); + } + body = new Stmt.While(condition, body); + + if (initializer != null) { + // insert the initializer before the while statement + body = new Stmt.Block(Arrays.asList(initializer, body)); + } + + return body; + } + private List block() { List statements = new ArrayList<>(); @@ -79,7 +155,7 @@ class Parser { } private Expr assignment() { - Expr expr = equality(); + Expr expr = or(); if (match(EQUAL)) { Token equals = previous(); @@ -95,6 +171,30 @@ class Parser { return expr; } + private Expr or() { + Expr expr = and(); + + while (match(OR)) { + Token operator = previous(); + Expr right = and(); + expr = new Expr.Logical(expr, operator, right); + } + + return expr; + } + + private Expr and() { + Expr expr = equality(); + + while (match(AND)) { + Token operator = previous(); + Expr right = equality(); + expr = new Expr.Logical(expr, operator, right); + } + + return expr; + } + private Expr expression() { return assignment(); } diff --git a/src/lox/Stmt.java b/src/lox/Stmt.java index 2dc0f68..f097da1 100644 --- a/src/lox/Stmt.java +++ b/src/lox/Stmt.java @@ -6,8 +6,10 @@ abstract class Stmt { interface Visitor { R visitBlockStmt(Block stmt); R visitExpressionStmt(Expression stmt); + R visitIfStmt(If stmt); R visitVarStmt(Var stmt); R visitPrintStmt(Print stmt); + R visitWhileStmt(While stmt); } static class Block extends Stmt { @@ -36,6 +38,23 @@ abstract class Stmt { final Expr expression; } + static class If extends Stmt { + If(Expr condition, Stmt thenBranch, Stmt elseBranch) { + this.condition = condition; + this.thenBranch = thenBranch; + this.elseBranch = elseBranch; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitIfStmt(this); + } + + final Expr condition; + final Stmt thenBranch; + final Stmt elseBranch; + } + static class Var extends Stmt { Var(Token name, Expr initializer) { this.name = name; @@ -64,5 +83,20 @@ abstract class Stmt { final Expr expression; } + static class While extends Stmt { + While(Expr condition, Stmt body) { + this.condition = condition; + this.body = body; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitWhileStmt(this); + } + + final Expr condition; + final Stmt body; + } + abstract R accept(Visitor visitor); } diff --git a/src/tool/GenerateAST.java b/src/tool/GenerateAST.java index 03ea38d..b91e423 100644 --- a/src/tool/GenerateAST.java +++ b/src/tool/GenerateAST.java @@ -22,6 +22,7 @@ public class GenerateAST { "Binary : Expr left, Token operator, Expr right", "Grouping : Expr expression", "Literal : Object value", + "Logical : Expr left, Token operator, Expr right", "Unary : Token operator, Expr right", "Variable : Token name")); @@ -31,8 +32,10 @@ public class GenerateAST { Arrays.asList( "Block : List statements", "Expression : Expr expression", + "If : Expr condition, Stmt thenBranch, Stmt elseBranch", "Var : Token name, Expr initializer", - "Print : Expr expression")); + "Print : Expr expression", + "While : Expr condition, Stmt body")); } private static void defineAST(String outputDir, String baseName, List types)