diff --git a/src/sql/Expr.cpp b/src/sql/Expr.cpp index c2103a3..6804fc5 100644 --- a/src/sql/Expr.cpp +++ b/src/sql/Expr.cpp @@ -218,7 +218,7 @@ Expr* Expr::makeInOperator(Expr* expr, SelectStatement* select) { } Expr* Expr::makeExtract(DatetimeField datetimeField, Expr* expr) { - Expr* e = new Expr(kExprFunctionRef); + Expr* e = new Expr(kExprExtract); e->name = strdup("EXTRACT"); e->datetimeField = datetimeField; e->expr = expr; @@ -226,7 +226,7 @@ Expr* Expr::makeExtract(DatetimeField datetimeField, Expr* expr) { } Expr* Expr::makeCast(Expr* expr, ColumnType columnType) { - Expr* e = new Expr(kExprFunctionRef); + Expr* e = new Expr(kExprCast); e->name = strdup("CAST"); e->columnType = columnType; e->expr = expr; diff --git a/src/sql/Expr.h b/src/sql/Expr.h index 6939e3f..cb5bddb 100644 --- a/src/sql/Expr.h +++ b/src/sql/Expr.h @@ -27,7 +27,8 @@ enum ExprType { kExprHint, kExprArray, kExprArrayIndex, - kExprDatetimeField + kExprExtract, + kExprCast }; // Operator types. These are important for expressions of type kExprOperator. diff --git a/src/util/sqlhelper.cpp b/src/util/sqlhelper.cpp index 74a5b4d..e31ff78 100755 --- a/src/util/sqlhelper.cpp +++ b/src/util/sqlhelper.cpp @@ -10,6 +10,7 @@ namespace hsql { void printAlias(Alias* alias, uintmax_t numIndent); std::ostream& operator<<(std::ostream& os, const OperatorType& op); + std::ostream& operator<<(std::ostream& os, const DatetimeField& op); std::string indent(uintmax_t numIndent) { return std::string(numIndent, '\t'); @@ -32,6 +33,12 @@ namespace hsql { void inprint(const OperatorType& op, uintmax_t numIndent) { std::cout << indent(numIndent) << op << std::endl; } + void inprint(const ColumnType& colType, uintmax_t numIndent) { + std::cout << indent(numIndent) << colType << std::endl; + } + void inprint(const DatetimeField& colType, uintmax_t numIndent) { + std::cout << indent(numIndent) << colType << std::endl; + } void printTableRefInfo(TableRef* table, uintmax_t numIndent) { switch (table->type) { @@ -118,6 +125,16 @@ namespace hsql { inprint(expr->name, numIndent); for (Expr* e : *expr->exprList) printExpression(e, numIndent + 1); break; + case kExprExtract: + inprint(expr->name, numIndent); + inprint(expr->datetimeField, numIndent + 1); + printExpression(expr->expr, numIndent + 1); + break; + case kExprCast: + inprint(expr->name, numIndent); + inprint(expr->columnType, numIndent + 1); + printExpression(expr->expr, numIndent + 1); + break; case kExprOperator: printOperatorExpression(expr, numIndent); break; @@ -372,4 +389,23 @@ namespace hsql { } } + std::ostream& operator<<(std::ostream& os, const DatetimeField& datetime) { + static const std::map operatorToToken = { + {kDatetimeNone, "None"}, + {kDatetimeSecond, "SECOND"}, + {kDatetimeMinute, "MINUTE"}, + {kDatetimeHour, "HOUR"}, + {kDatetimeDay, "DAY"}, + {kDatetimeMonth, "MONTH"}, + {kDatetimeYear, "YEAR"} + }; + + const auto found = operatorToToken.find(datetime); + if (found == operatorToToken.cend()) { + return os << static_cast(datetime); + } else { + return os << (*found).second; + } + } + } // namespace hsql diff --git a/test/select_tests.cpp b/test/select_tests.cpp index 3093ceb..6833770 100644 --- a/test/select_tests.cpp +++ b/test/select_tests.cpp @@ -638,7 +638,7 @@ TEST(Extract) { stmt = (SelectStatement*) result.getStatement(0); ASSERT_TRUE(stmt->selectList); ASSERT_EQ(stmt->selectList->size(), 1u); - ASSERT_EQ(stmt->selectList->at(0)->type, kExprFunctionRef); + ASSERT_EQ(stmt->selectList->at(0)->type, kExprExtract); ASSERT_EQ(stmt->selectList->at(0)->name, std::string("EXTRACT")); ASSERT_EQ(stmt->selectList->at(0)->datetimeField, kDatetimeYear); ASSERT_TRUE(stmt->selectList->at(0)->expr); @@ -647,7 +647,7 @@ TEST(Extract) { stmt = (SelectStatement*) result.getStatement(1); ASSERT_TRUE(stmt->selectList); ASSERT_EQ(stmt->selectList->size(), 2u); - ASSERT_EQ(stmt->selectList->at(1)->type, kExprFunctionRef); + ASSERT_EQ(stmt->selectList->at(1)->type, kExprExtract); ASSERT_EQ(stmt->selectList->at(1)->name, std::string("EXTRACT")); ASSERT_EQ(stmt->selectList->at(1)->datetimeField, kDatetimeMonth); ASSERT_TRUE(stmt->selectList->at(1)->expr); @@ -658,11 +658,31 @@ TEST(Extract) { stmt = (SelectStatement*) result.getStatement(2); ASSERT_TRUE(stmt->whereClause); ASSERT_TRUE(stmt->whereClause->expr); - ASSERT_EQ(stmt->whereClause->expr->type, kExprFunctionRef); + ASSERT_EQ(stmt->whereClause->expr->type, kExprExtract); ASSERT_EQ(stmt->whereClause->expr->name, std::string("EXTRACT")); ASSERT_EQ(stmt->whereClause->expr->datetimeField, kDatetimeMinute); } +TEST(CastExpression) { + TEST_PARSE_SINGLE_SQL( + "SELECT CAST(10 AS INT);", + kStmtSelect, + SelectStatement, + result, + stmt); + + ASSERT_TRUE(stmt->selectList); + ASSERT_FALSE(stmt->fromTable); + ASSERT_FALSE(stmt->whereClause); + ASSERT_FALSE(stmt->groupBy); + + ASSERT_EQ(stmt->selectList->size(), 1u); + ASSERT_EQ(stmt->selectList->at(0)->type, kExprCast); + ASSERT_EQ(stmt->selectList->at(0)->name, std::string("CAST")); + ASSERT_EQ(stmt->selectList->at(0)->columnType, ColumnType(DataType::INT)); + ASSERT_EQ(stmt->selectList->at(0)->expr->type, kExprLiteralInt); +} + TEST(NoFromClause) { TEST_PARSE_SINGLE_SQL( "SELECT 1 + 2;",