Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql_generation): handle scenario where table columns have "from" keyword in query #1600

Merged
merged 5 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pandasai/core/code_generation/code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pandasai.agent.state import AgentState
from pandasai.constants import DEFAULT_CHART_DIRECTORY
from pandasai.core.code_execution.code_executor import CodeExecutor
from pandasai.helpers.sql import extract_table_names
from pandasai.query_builders.sql_parser import SQLParser

from ...exceptions import MaliciousQueryError

Expand Down Expand Up @@ -53,10 +53,11 @@ def _clean_sql_query(self, sql_query: str) -> str:
Clean the SQL query by trimming semicolons and validating table names.
"""
sql_query = sql_query.rstrip(";")
table_names = extract_table_names(sql_query)
table_names = SQLParser.extract_table_names(sql_query)
allowed_table_names = {
df.schema.name: df.schema.name for df in self.context.dfs
} | {f'"{df.schema.name}"': df.schema.name for df in self.context.dfs}

return self._replace_table_names(sql_query, table_names, allowed_table_names)

def _validate_and_make_table_name_case_sensitive(self, node: ast.AST) -> ast.AST:
Expand Down
2 changes: 1 addition & 1 deletion pandasai/data_loader/local_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
try:
db_manager = DuckDBConnectionManager()

if not is_sql_query_safe(query):
if not is_sql_query_safe(query, dialect="duckdb"):

Check warning on line 96 in pandasai/data_loader/local_loader.py

View check run for this annotation

Codecov / codecov/patch

pandasai/data_loader/local_loader.py#L96

Added line #L96 was not covered by tests
raise MaliciousQueryError(
"The SQL query is deemed unsafe and will not be executed."
)
Expand Down
2 changes: 1 addition & 1 deletion pandasai/data_loader/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
load_function = self._get_loader_function(source_type)
query = SQLParser.transpile_sql_dialect(query, to_dialect=source_type)

if not is_sql_query_safe(query):
if not is_sql_query_safe(query, source_type):
raise MaliciousQueryError(
"The SQL query is deemed unsafe and will not be executed."
)
Expand Down
2 changes: 1 addition & 1 deletion pandasai/data_loader/view_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
load_function = self._get_loader_function(source_type)
query = SQLParser.transpile_sql_dialect(query, to_dialect=source_type)

if not is_sql_query_safe(query):
if not is_sql_query_safe(query, dialect=source_type):
raise MaliciousQueryError(
"The SQL query is deemed unsafe and will not be executed."
)
Expand Down
4 changes: 2 additions & 2 deletions pandasai/helpers/dataframe_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self) -> None:
pass

@staticmethod
def serialize(df: "DataFrame") -> str:
def serialize(df: "DataFrame", dialect: str = "postgres") -> str:
"""
Convert df to csv like format where csv is wrapped inside <dataframe></dataframe>
Args:
Expand All @@ -18,7 +18,7 @@ def serialize(df: "DataFrame") -> str:
Returns:
str: dataframe stringify
"""
dataframe_info = f'<table table_name="{df.schema.name}"'
dataframe_info = f'<table dialect="{dialect}" table_name="{df.schema.name}"'

# Add description attribute if available
if df.schema.description is not None:
Expand Down
10 changes: 0 additions & 10 deletions pandasai/helpers/sql.py

This file was deleted.

4 changes: 2 additions & 2 deletions pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def sanitize_file_name(filepath: str) -> str:
return sanitize_sql_table_name(file_name)


def is_sql_query_safe(query: str) -> bool:
def is_sql_query_safe(query: str, dialect: str = "postgres") -> bool:
try:
# List of infected keywords to block (you can add more)
infected_keywords = [
Expand Down Expand Up @@ -72,7 +72,7 @@ def is_sql_query_safe(query: str) -> bool:
temp_query = query.replace("%s", placeholder)

# Parse the query to extract its structure
parsed = sqlglot.parse_one(temp_query)
parsed = sqlglot.parse_one(temp_query, dialect=dialect)

# Ensure the main query is SELECT
if parsed.key.upper() != "SELECT":
Expand Down
25 changes: 25 additions & 0 deletions pandasai/query_builders/sql_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import List, Optional

import sqlglot
from sqlglot import ParseError, exp, parse_one
from sqlglot.optimizer.qualify_columns import quote_identifiers

from pandasai.exceptions import MaliciousQueryError


class SQLParser:
@staticmethod
Expand All @@ -27,6 +31,7 @@ def transform_node(node):
# Handle Table nodes
if isinstance(node, exp.Table):
original_name = node.name

if original_name in table_mapping:
alias = node.alias or original_name
mapped_value = parsed_mapping[original_name]
Expand Down Expand Up @@ -57,3 +62,23 @@ def transpile_sql_dialect(query, to_dialect, from_dialect=None):
parse_one(query, read=from_dialect) if from_dialect else parse_one(query)
)
return query.sql(dialect=to_dialect, pretty=True)

@staticmethod
def extract_table_names(sql_query: str, dialect: str = "postgres") -> List[str]:
    """Return the base table names referenced by *sql_query*.

    CTE aliases declared in ``WITH`` clauses are excluded, so only real
    (physical) tables are reported. Table order follows their appearance
    in the parsed statements.

    Args:
        sql_query: The SQL text to inspect (may contain multiple statements).
        dialect: sqlglot dialect used for parsing (defaults to "postgres").

    Returns:
        List of table names, excluding any CTE aliases.
    """
    statements = sqlglot.parse(sql_query, dialect=dialect)
    seen_ctes: set = set()
    tables: List[str] = []

    for statement in statements:
        # Register every CTE alias declared in this statement first, so the
        # table scan below can tell real tables apart from CTE references.
        for with_clause in statement.find_all(exp.With):
            seen_ctes.update(expr.alias_or_name for expr in with_clause.expressions)

        # Collect table references that are not CTE aliases.
        tables.extend(
            table.name
            for table in statement.find_all(exp.Table)
            if table.name not in seen_ctes
        )

    return tables
22 changes: 14 additions & 8 deletions tests/unit_tests/core/code_generation/test_code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,27 @@ def test_replace_table_names_invalid(self):
)

def test_clean_sql_query(self):
table = self.sample_df.schema.name
sql_query = f"SELECT * FROM {table};"
self.cleaner.context.dfs = [self.sample_df]
sql_query = "SELECT * FROM my_table;"
mock_dataframe = MagicMock(spec=object)
mock_dataframe.name = "my_table"
mock_dataframe.schema = MagicMock()
mock_dataframe.schema.name = "my_table"
self.cleaner.context.dfs = [mock_dataframe]
result = self.cleaner._clean_sql_query(sql_query)
self.assertEqual(result, f"SELECT * FROM {table}")
self.assertEqual(result, "SELECT * FROM my_table")

def test_validate_and_make_table_name_case_sensitive(self):
table = self.sample_df.schema.name
node = ast.Assign(
targets=[ast.Name(id="query", ctx=ast.Store())],
value=ast.Constant(value=f"SELECT * FROM {table}"),
value=ast.Constant(value="SELECT * FROM my_table"),
)
self.cleaner.context.dfs = [self.sample_df]
mock_dataframe = MagicMock(spec=object)
mock_dataframe.name = "my_table"
self.cleaner.context.dfs = [mock_dataframe]
mock_dataframe.schema = MagicMock()
mock_dataframe.schema.name = "my_table"
updated_node = self.cleaner._validate_and_make_table_name_case_sensitive(node)
self.assertEqual(updated_node.value.value, f"SELECT * FROM {table}")
self.assertEqual(updated_node.value.value, "SELECT * FROM my_table")

def test_extract_fix_dataframe_redeclarations(self):
node = ast.Assign(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/data_loader/test_sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_mysql_malicious_query(self, mysql_schema):
with pytest.raises(MaliciousQueryError):
loader.execute_query("DROP TABLE users")

mock_sql_query.assert_called_once_with("DROP TABLE users")
mock_sql_query.assert_called_once_with("DROP TABLE users", "mysql")

def test_mysql_safe_query(self, mysql_schema):
"""Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_mysql_safe_query(self, mysql_schema):
result = loader.execute_query("SELECT * FROM users")

assert isinstance(result, DataFrame)
mock_sql_query.assert_called_once_with("SELECT\n *\nFROM users")
mock_sql_query.assert_called_once_with("SELECT\n *\nFROM users", "mysql")

def test_mysql_malicious_with_no_import(self, mysql_schema):
"""Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
Expand Down
15 changes: 14 additions & 1 deletion tests/unit_tests/helpers/test_dataframe_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,20 @@ def test_serialize_with_name_and_description(self, sample_df):
"""Test serialization with name and description attributes."""

result = DataframeSerializer.serialize(sample_df)
expected = """<table table_name="table_6c30b42101939c7bdf95f4c1052d615c" dimensions="3x2">
expected = """<table dialect="postgres" table_name="table_6c30b42101939c7bdf95f4c1052d615c" dimensions="3x2">
A,B
1,4
2,5
3,6
</table>
"""
assert result.replace("\r\n", "\n") == expected.replace("\r\n", "\n")

def test_serialize_with_name_and_description_with_dialect(self, sample_df):
"""Test serialization with name and description attributes."""

result = DataframeSerializer.serialize(sample_df, dialect="mysql")
expected = """<table dialect="mysql" table_name="table_6c30b42101939c7bdf95f4c1052d615c" dimensions="3x2">
A,B
1,4
2,5
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/prompts/test_sql_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_str_with_args(self, output_type, output_type_template):
prompt_content
== f'''<tables>

<table table_name="table_d41d8cd98f00b204e9800998ecf8427e" dimensions="0x0">
<table dialect="postgres" table_name="table_d41d8cd98f00b204e9800998ecf8427e" dimensions="0x0">

</table>

Expand Down
72 changes: 72 additions & 0 deletions tests/unit_tests/query_builders/test_sql_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from pandasai.exceptions import MaliciousQueryError
from pandasai.query_builders.sql_parser import SQLParser


Expand Down Expand Up @@ -51,6 +52,29 @@ class TestSqlParser:
) AS o
ON "c"."id" = "o"."customer_id"''',
),
(
"""SELECT d.name AS department, hse.name AS employee, hse.salary
FROM (
SELECT * FROM employees WHERE salary > 50000
) AS hse
JOIN departments d ON hse.dept_id = d.id;
""",
{"employees": "employee", "departments": "department"},
"""SELECT
"d"."name" AS "department",
"hse"."name" AS "employee",
"hse"."salary"
FROM (
SELECT
*
FROM "employee" AS employees
WHERE
"salary" > 50000
) AS "hse"
JOIN "department" AS d
ON "hse"."dept_id" = "d"."id"
""",
),
],
)
def test_replace_table_names(query, table_mapping, expected):
Expand All @@ -62,3 +86,51 @@ def test_mysql_transpilation(self):
expected = """SELECT\n COUNT(*) AS `total_rows`"""
result = SQLParser.transpile_sql_dialect(query, to_dialect="mysql")
assert result.strip() == expected.strip()

@staticmethod
@pytest.mark.parametrize(
    "sql_query, dialect, expected_tables",
    [
        # 1. Simple SELECT query
        ("SELECT * FROM users;", "postgres", ["users"]),
        # 2. Query with INNER JOIN
        (
            "SELECT * FROM users u JOIN orders o ON u.id = o.user_id;",
            "postgres",
            ["users", "orders"],
        ),
        # 3. Query with LEFT JOIN
        (
            "SELECT * FROM customers c LEFT JOIN orders o ON c.id = o.customer_id;",
            "postgres",
            ["customers", "orders"],
        ),
        # 4. Subquery
        (
            "SELECT * FROM (SELECT * FROM employees) AS e;",
            "postgres",
            ["employees"],
        ),
        # 5. CTE (Common Table Expression) — the CTE alias must not be reported
        (
            """
            WITH sales_data AS (SELECT * FROM sales)
            SELECT * FROM sales_data;
        """,
            "postgres",
            ["sales"],
        ),
        # 6. Table with alias (should return original table name)
        ("SELECT u.name FROM users AS u;", "postgres", ["users"]),
        # 7. Schema-prefixed table
        ("SELECT * FROM sales.customers;", "postgres", ["customers"]),
        # 8. Quoted table names (double quotes for PostgreSQL, backticks for MySQL)
        ('SELECT * FROM "Order Details";', "postgres", ["Order Details"]),
        # ("SELECT * FROM `Order Details`;", "mysql", ["Order Details"]),
        # 9. Edge case: query with no table references (should return an empty list)
        ("SELECT *", "postgres", []),
    ],
)
def test_extract_table_names(sql_query, dialect, expected_tables):
    """extract_table_names returns base tables only, excluding CTE aliases."""
    # Call once and compare — the original called the method twice and left
    # the first result in an unused variable.
    assert SQLParser.extract_table_names(sql_query, dialect) == expected_tables
Loading