|
|
@@ -85,14 +85,14 @@ class TrinoApi(Api):
|
|
|
else DEFAULT_AUTH_PASSWORD.get()
|
|
|
)
|
|
|
|
|
|
- def _format_identifier(self, identifier):
|
|
|
+ def _format_identifier(self, identifier, is_db=False):
|
|
|
# Remove any backticks
|
|
|
identifier = identifier.replace('`', '')
|
|
|
|
|
|
# Check if already formatted
|
|
|
if not (identifier.startswith('"') and identifier.endswith('"')):
|
|
|
# Check if it's a multi-part identifier (e.g., catalog.schema)
|
|
|
- if '.' in identifier:
|
|
|
+ if '.' in identifier and is_db:
|
|
|
# Split and format each part separately
|
|
|
identifier = '"{}"'.format('"."'.join(identifier.split('.')))
|
|
|
else:
|
|
|
@@ -113,7 +113,7 @@ class TrinoApi(Api):
|
|
|
@query_error_handler
|
|
|
def execute(self, notebook, snippet):
|
|
|
database = snippet['database']
|
|
|
- database = self._format_identifier(database)
|
|
|
+ database = self._format_identifier(database, is_db=True)
|
|
|
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
|
|
|
query_client.execute()
|
|
|
|
|
|
@@ -258,7 +258,7 @@ class TrinoApi(Api):
|
|
|
if operation == 'hello':
|
|
|
statement = "SELECT 'Hello World!'"
|
|
|
else:
|
|
|
- database = self._format_identifier(database)
|
|
|
+ database = self._format_identifier(database, is_db=True)
|
|
|
table = self._format_identifier(table)
|
|
|
column = '%(column)s' % {'column': self._format_identifier(column)} if column else '*'
|
|
|
statement = textwrap.dedent('''\
|
|
|
@@ -312,7 +312,7 @@ class TrinoApi(Api):
|
|
|
return catalogs
|
|
|
|
|
|
def _show_tables(self, database):
|
|
|
- database = self._format_identifier(database)
|
|
|
+ database = self._format_identifier(database, is_db=True)
|
|
|
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
|
|
|
query_client.execute()
|
|
|
query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
|
|
|
@@ -327,7 +327,7 @@ class TrinoApi(Api):
|
|
|
]
|
|
|
|
|
|
def _get_columns(self, database, table):
|
|
|
- database = self._format_identifier(database)
|
|
|
+ database = self._format_identifier(database, is_db=True)
|
|
|
query_client = TrinoQuery(self.trino_request, 'USE ' + database)
|
|
|
query_client.execute()
|
|
|
table = self._format_identifier(table)
|
|
|
@@ -356,7 +356,7 @@ class TrinoApi(Api):
|
|
|
if statement:
|
|
|
try:
|
|
|
database = snippet['database']
|
|
|
- database = self._format_identifier(database)
|
|
|
+ database = self._format_identifier(database, is_db=True)
|
|
|
TrinoQuery(self.trino_request, 'USE ' + database).execute()
|
|
|
result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute()
|
|
|
explanation = result.rows
|