瀏覽代碼

HUE-7475 [core] Merge HUE-3102 to pysaml2 v4.4.0

(cherry picked from commit 61414bd5e6ac06dfb95e0cab735dbf3f83dfac95)
Ying Chen 8 年之前
父節點
當前提交
b5899021a2
共有 2 個文件被更改,包括 129 次插入31 次删除
  1. 18 4
      desktop/core/ext-py/pysaml2-4.4.0/src/saml2/config.py
  2. 111 27
      desktop/core/ext-py/pysaml2-4.4.0/src/saml2/sigver.py

+ 18 - 4
desktop/core/ext-py/pysaml2-4.4.0/src/saml2/config.py

@@ -28,10 +28,23 @@ __author__ = 'rolandh'
 
 
 COMMON_ARGS = [
-    "entityid", "xmlsec_binary", "debug", "key_file", "cert_file",
-    "encryption_keypairs", "additional_cert_files",
-    "metadata_key_usage", "secret", "accepted_time_diff", "name", "ca_certs",
-    "description", "valid_for", "verify_ssl_cert",
+    "entityid",
+    "xmlsec_binary",
+    "debug",
+    "key_file",
+    "key_file_passphrase",
+    "cert_file",
+    "encryption_type",
+    "encryption_keypairs",
+    "additional_cert_files",
+    "metadata_key_usage",
+    "secret",
+    "accepted_time_diff",
+    "name",
+    "ca_certs",
+    "description",
+    "valid_for",
+    "verify_ssl_cert",
     "organization",
     "contact_person",
     "name_form",
@@ -172,6 +185,7 @@ class Config(object):
         self.xmlsec_path = []
         self.debug = False
         self.key_file = None
+        self.key_file_passphrase = None
         self.cert_file = None
         self.encryption_keypairs = None
         self.additional_cert_files = None

+ 111 - 27
desktop/core/ext-py/pysaml2-4.4.0/src/saml2/sigver.py

@@ -39,7 +39,7 @@ from Cryptodome.Hash import SHA256
 from Cryptodome.Hash import SHA384
 from Cryptodome.Hash import SHA512
 
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, mkdtemp
 from subprocess import Popen
 from subprocess import PIPE
 
@@ -575,8 +575,8 @@ def pem_format(key):
                       key, "-----END CERTIFICATE-----"]).encode('ascii')
 
 
-def import_rsa_key_from_file(filename):
-    return RSA.importKey(read_file(filename, 'r'))
+def import_rsa_key_from_file(filename, key_file_passphrase=None):
+    return RSA.importKey(read_file(filename, 'r'), key_file_passphrase)
 
 
 def parse_xmlsec_output(output):
@@ -597,6 +597,24 @@ def parse_xmlsec_output(output):
 def sha1_digest(msg):
     return hashlib.sha1(msg).digest()
 
+# --------------------------------------------------------------------------
+
+class NamedPipe(object):
+    def __init__(self):
+        self._tempdir = mkdtemp()
+        self.name = os.path.join(self._tempdir, 'fifo')
+
+        try:
+            os.mkfifo(self.name)
+        except:
+            os.rmdir(self._tempdir)
+
+    def close(self):
+        os.remove(self.name)
+        os.rmdir(self._tempdir)
+
+# --------------------------------------------------------------------------
+
 
 class Signer(object):
     """Abstract base class for signing algorithms."""
@@ -772,7 +790,7 @@ class CryptoBackend():
                           node_xpath):
         raise NotImplementedError()
 
-    def decrypt(self, enctext, key_file):
+    def decrypt(self, enctext, key_file, passphrase=None):
         raise NotImplementedError()
 
     def sign_statement(self, statement, node_name, key_file, node_id,
@@ -884,7 +902,7 @@ class CryptoBackendXmlSec1(CryptoBackend):
 
         return output.decode('utf-8')
 
-    def decrypt(self, enctext, key_file):
+    def decrypt(self, enctext, key_file, passphrase=None):
         """
 
         :param enctext: XML document containing an encrypted part
@@ -895,16 +913,19 @@ class CryptoBackendXmlSec1(CryptoBackend):
         logger.debug("Decrypt input len: %d", len(enctext))
         _, fil = make_temp(str(enctext).encode('utf-8'), decode=False)
 
-        com_list = [self.xmlsec, "--decrypt", "--privkey-pem",
-                    key_file, "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS]
+        com_list = [self.xmlsec, "--decrypt", "--id-attr:%s" % ID_ATTR,
+                    ENC_KEY_CLASS]
 
         (_stdout, _stderr, output) = self._run_xmlsec(com_list, [fil],
                                                       exception=DecryptError,
-                                                      validate_output=False)
+                                                      validate_output=False,
+                                                      key_file=key_file,
+                                                      passphrase=passphrase)
+
         return output.decode('utf-8')
 
     def sign_statement(self, statement, node_name, key_file, node_id,
-                       id_attr):
+                       id_attr, passphrase=None):
         """
         Sign an XML statement.
 
@@ -919,18 +940,18 @@ class CryptoBackendXmlSec1(CryptoBackend):
         if isinstance(statement, SamlBase):
             statement = str(statement)
 
-        _, fil = make_temp(statement, suffix=".xml",
-                           decode=False, delete=self._xmlsec_delete_tmpfiles)
+        _, fil = make_temp("%s" % statement, suffix=".xml", decode=False,
+                           delete=self._xmlsec_delete_tmpfiles)
 
         com_list = [self.xmlsec, "--sign",
-                    "--privkey-pem", key_file,
                     "--id-attr:%s" % id_attr, node_name]
         if node_id:
             com_list.extend(["--node-id", node_id])
 
         try:
             (stdout, stderr, signed_statement) = self._run_xmlsec(
-                com_list, [fil], validate_output=False)
+                com_list, [fil], validate_output=False, key_file=key_file,
+                passphrase=passphrase)
             # this doesn't work if --store-signatures are used
             if stdout == "":
                 if signed_statement:
@@ -987,7 +1008,9 @@ class CryptoBackendXmlSec1(CryptoBackend):
         return parse_xmlsec_output(stderr)
 
     def _run_xmlsec(self, com_list, extra_args, validate_output=True,
-                    exception=XmlsecError):
+                    exception=XmlsecError,
+                    key_file=None,
+                    passphrase=None):
         """
         Common code to invoke xmlsec and parse the output.
         :param com_list: Key-value parameter list for xmlsec
@@ -999,13 +1022,40 @@ class CryptoBackendXmlSec1(CryptoBackend):
         """
         ntf = NamedTemporaryFile(suffix=".xml",
                                  delete=self._xmlsec_delete_tmpfiles)
+
         com_list.extend(["--output", ntf.name])
+
+        # Unfortunately there's no safe way to pass a password to xmlsec1.
+        # Instead, we'll decrypt the certificate and write it into a named pipe,
+        # which we'll pass to xmlsec1.
+        named_pipe = None
+        if key_file is not None:
+            if passphrase is not None:
+                named_pipe = NamedPipe()
+
+                # Decrypt the certificate, but don't write it into the FIFO
+                # until after we've started xmlsec1.
+                with open(key_file) as f:
+                    key = importKey(f.read(), passphrase=passphrase)
+
+                key_file = named_pipe.name
+
+            com_list.extend(["--privkey-pem", key_file])
+
         com_list += extra_args
 
         logger.debug("xmlsec command: %s", " ".join(com_list))
 
         pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
 
+        if named_pipe is not None:
+            # Finally, write the key into our named pipe.
+            try:
+                with open(named_pipe.name, 'wb') as f:
+                    f.write(key.exportKey())
+            finally:
+                named_pipe.close()
+
         p_out = pof.stdout.read().decode('utf-8')
         p_err = pof.stderr.read().decode('utf-8')
         pof.wait()
@@ -1048,7 +1098,7 @@ class CryptoBackendXMLSecurity(CryptoBackend):
         return "XMLSecurity 0.0"
 
     def sign_statement(self, statement, node_name, key_file, node_id,
-                       _id_attr):
+                       _id_attr, passphrase=None):
         """
         Sign an XML statement.
 
@@ -1064,6 +1114,8 @@ class CryptoBackendXMLSecurity(CryptoBackend):
         import xmlsec
         import lxml.etree
 
+        assert passphrase is None, "Encrypted key files is not supported"
+
         xml = xmlsec.parse_xml(statement)
         signed = xmlsec.sign(xml, key_file)
         return lxml.etree.tostring(signed, xml_declaration=True)
@@ -1136,7 +1188,7 @@ def security_context(conf, debug=None):
         _file_name = conf.getattr("key_file", "")
         if _file_name:
             try:
-                rsa_key = import_rsa_key_from_file(_file_name)
+                rsa_key = import_rsa_key_from_file(_file_name, conf.key_file_passphrase)
             except Exception as err:
                 logger.error("Could not import key from {}: {}".format(_file_name,
                                                                        err))
@@ -1165,6 +1217,7 @@ def security_context(conf, debug=None):
         tmp_cert_file=conf.tmp_cert_file,
         tmp_key_file=conf.tmp_key_file,
         validate_certificate=conf.validate_certificate,
+        key_file_passphrase=conf.key_file_passphrase,
         enc_key_files=enc_key_files,
         encryption_keypairs=conf.encryption_keypairs,
         sec_backend=sec_backend)
@@ -1349,6 +1402,7 @@ class SecurityContext(object):
                  only_use_keys_in_metadata=False, cert_handler_extra_class=None,
                  generate_cert_info=None, tmp_cert_file=None,
                  tmp_key_file=None, validate_certificate=None,
+                 key_file_passphrase=None,
                  enc_key_files=None, enc_key_type="pem",
                  encryption_keypairs=None, enc_cert_type="pem",
                  sec_backend=None):
@@ -1362,6 +1416,7 @@ class SecurityContext(object):
 
         # Your private key for signing
         self.key_file = key_file
+        self.key_file_passphrase = key_file_passphrase
         self.key_type = key_type
 
         # Your public key for signing
@@ -1463,7 +1518,7 @@ class SecurityContext(object):
                     return _enctext
         return enctext
 
-    def decrypt(self, enctext, key_file=None):
+    def decrypt(self, enctext, key_file=None, passphrase=None):
         """ Decrypting an encrypted text by the use of a private key.
 
         :param enctext: The encrypted text as a string
@@ -1475,8 +1530,15 @@ class SecurityContext(object):
                 _enctext = self.crypto.decrypt(enctext, _enc_key_file)
                 if _enctext is not None and len(_enctext) > 0:
                     return _enctext
+
+        if key_file is None or len(key_file.strip()) == 0:
+            key_file = self.key_file
+
+        if passphrase is None:
+            passphrase = self.key_file_passphrase
+
         if key_file is not None and len(key_file.strip()) > 0:
-            _enctext = self.crypto.decrypt(enctext, key_file)
+            _enctext = self.crypto.decrypt(enctext, key_file, passphrase)
             if _enctext is not None and len(_enctext) > 0:
                 return _enctext
         return enctext
@@ -1558,11 +1620,27 @@ class SecurityContext(object):
         for _, pem_file in certs:
             try:
                 last_pem_file = pem_file
-                if self.verify_signature(decoded_xml, pem_file,
-                                         node_name=node_name,
-                                         node_id=item.id, id_attr=id_attr):
-                    verified = True
-                    break
+                if origdoc is not None:
+                    try:
+                        if self.verify_signature(origdoc, pem_file,
+                                                 node_name=node_name,
+                                                 node_id=item.id,
+                                                 id_attr=id_attr):
+                            verified = True
+                            break
+                    except Exception:
+                        if self.verify_signature(decoded_xml, pem_file,
+                                                 node_name=node_name,
+                                                 node_id=item.id,
+                                                 id_attr=id_attr):
+                            verified = True
+                            break
+                else:
+                    if self.verify_signature(decoded_xml, pem_file,
+                                             node_name=node_name,
+                                             node_id=item.id, id_attr=id_attr):
+                        verified = True
+                        break
             except XmlsecError as exc:
                 logger.error("check_sig: %s", exc)
                 pass
@@ -1773,7 +1851,8 @@ class SecurityContext(object):
         return self.sign_statement(statement, **kwargs)
 
     def sign_statement(self, statement, node_name, key=None,
-                       key_file=None, node_id=None, id_attr=""):
+                       key_file=None, node_id=None, id_attr="",
+                       passphrase=None):
         """Sign a SAML statement.
 
         :param statement: The statement to be signed
@@ -1794,8 +1873,12 @@ class SecurityContext(object):
         if not key and not key_file:
             key_file = self.key_file
 
+        if not passphrase:
+            passphrase = self.key_file_passphrase
+
         return self.crypto.sign_statement(statement, node_name, key_file,
-                                          node_id, id_attr)
+                                          node_id, id_attr,
+                                          passphrase=passphrase)
 
     def sign_assertion_using_xmlsec(self, statement, **kwargs):
         """ Deprecated function. See sign_assertion(). """
@@ -1829,7 +1912,7 @@ class SecurityContext(object):
             samlp.AttributeQuery()), **kwargs)
 
     def multiple_signatures(self, statement, to_sign, key=None, key_file=None,
-                            sign_alg=None, digest_alg=None):
+                            sign_alg=None, digest_alg=None, passphrase=None):
         """
         Sign multiple parts of a statement
 
@@ -1854,7 +1937,8 @@ class SecurityContext(object):
 
             statement = self.sign_statement(statement, class_name(item),
                                             key=key, key_file=key_file,
-                                            node_id=sid, id_attr=id_attr)
+                                            node_id=sid, id_attr=id_attr,
+                                            passphrase=passphrase)
         return statement