Browse Source

HUE-9216 [sqlalchemy] Foreign keys

Romain 6 years ago
parent
commit
e9491cf873

+ 17 - 4
desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

@@ -57,7 +57,7 @@ import textwrap
 from string import Template
 
 from django.utils.translation import ugettext as _
-from sqlalchemy import create_engine, inspect
+from sqlalchemy import create_engine, inspect, Table, MetaData
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.types import NullType
 
@@ -300,17 +300,19 @@ class SqlAlchemyApi(Api):
     elif column is None:
       database = self._fix_phoenix_empty_database(database)
       columns = assist.get_columns(database, table)
-      response['columns'] = [col['name'] for col in columns]
 
+      response['columns'] = [col['name'] for col in columns]
       response['extended_columns'] = [{
           'autoincrement': col.get('autoincrement'),
           'comment': col.get('comment'),
           'default': col.get('default'),
           'name': col.get('name'),
           'nullable': col.get('nullable'),
-          'type': str(col.get('type')) if not isinstance(col.get('type'), NullType) else 'Null'
-        } for col in columns
+          'type': str(col.get('type')) if not isinstance(col.get('type'), NullType) else 'Null',
+        }
+        for col in columns
       ]
+      response['foreign_keys'] = assist.get_foreign_keys(database, table)
     else:
       columns = assist.get_columns(database, table)
       response['name'] = next((col['name'] for col in columns if column == col['name']), '')
@@ -413,3 +415,14 @@ class Assist(object):
       return result.cursor.description, result.fetchall()
     finally:
       connection.close()
+
+  def get_foreign_keys(self, database, table):
+    meta = MetaData()
+    metaTable = Table(table, meta, schema=database, autoload=True, autoload_with=self.engine)
+
+    return [{
+        'name': fk.parent.name,
+        'to': fk.target_fullname
+      }
+      for fk in metaTable.foreign_keys
+    ]

+ 35 - 1
desktop/libs/notebook/src/notebook/connectors/sql_alchemy_tests.py

@@ -28,7 +28,7 @@ from desktop.lib.django_test_util import make_logged_in_client
 from useradmin.models import User
 
 from notebook.connectors.base import AuthenticationRequired
-from notebook.connectors.sql_alchemy import SqlAlchemyApi
+from notebook.connectors.sql_alchemy import SqlAlchemyApi, Assist
 
 
 if sys.version_info[0] > 2:
@@ -300,5 +300,39 @@ class TestAutocomplete(object):
 
           data = SqlAlchemyApi(self.user, interpreter).autocomplete(snippet, database='database', table='table')
 
+          assert_equal(data['columns'], ['col1', 'col2'])
+          assert_equal([col['type'] for col in data['extended_columns']], ['string', 'Null'])
+
+  def test_get_foreign_keys(self):
+
+    interpreter = {
+      'options': {'url': 'phoenix://'}
+    }
+
+    snippet = Mock()
+    with patch('notebook.connectors.sql_alchemy.create_engine') as create_engine:
+      with patch('notebook.connectors.sql_alchemy.inspect') as inspect:
+        with patch('notebook.connectors.sql_alchemy.Assist') as Assist:
+          def col1_dict(key):
+            return {
+              'name': 'col1',
+              'type': 'string'
+            }.get(key, Mock())
+          col1 = MagicMock()
+          col1.__getitem__.side_effect = col1_dict
+          col1.get = col1_dict
+          def col2_dict(key):
+            return {
+              'name': 'col2',
+              'type': NullType()
+            }.get(key, Mock())
+          col2 = MagicMock()
+          col2.__getitem__.side_effect = col2_dict
+          col2.get = col2_dict
+
+          Assist.return_value=Mock(get_columns=Mock(return_value=[col1, col2]))
+
+
+
           assert_equal(data['columns'], ['col1', 'col2'])
           assert_equal([col['type'] for col in data['extended_columns']], ['string', 'Null'])