s3: cache client instance (#34372)
This commit is contained in:
		@@ -44,7 +44,7 @@ def __getattr__(self, key):
 | 
			
		||||
 | 
			
		||||
def _s3_open(url):
 | 
			
		||||
    parsed = url_util.parse(url)
 | 
			
		||||
    s3 = s3_util.create_s3_session(parsed, connection=s3_util.get_mirror_connection(parsed))
 | 
			
		||||
    s3 = s3_util.get_s3_session(url, method="fetch")
 | 
			
		||||
 | 
			
		||||
    bucket = parsed.netloc
 | 
			
		||||
    key = parsed.path
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,7 @@
 | 
			
		||||
import llnl.util.tty as tty
 | 
			
		||||
 | 
			
		||||
import spack.config
 | 
			
		||||
import spack.mirror
 | 
			
		||||
import spack.paths
 | 
			
		||||
import spack.util.s3
 | 
			
		||||
import spack.util.web
 | 
			
		||||
@@ -246,14 +247,24 @@ def get_object(self, Bucket=None, Key=None):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_gather_s3_information(monkeypatch, capfd):
 | 
			
		||||
    mock_connection_data = {
 | 
			
		||||
    mirror = spack.mirror.Mirror.from_dict(
 | 
			
		||||
        {
 | 
			
		||||
            "fetch": {
 | 
			
		||||
                "access_token": "AAAAAAA",
 | 
			
		||||
                "profile": "SPacKDeV",
 | 
			
		||||
                "access_pair": ("SPA", "CK"),
 | 
			
		||||
                "endpoint_url": "https://127.0.0.1:8888",
 | 
			
		||||
            },
 | 
			
		||||
            "push": {
 | 
			
		||||
                "access_token": "AAAAAAA",
 | 
			
		||||
                "profile": "SPacKDeV",
 | 
			
		||||
                "access_pair": ("SPA", "CK"),
 | 
			
		||||
                "endpoint_url": "https://127.0.0.1:8888",
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mock_connection_data)
 | 
			
		||||
    session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mirror, "push")
 | 
			
		||||
 | 
			
		||||
    # Session args are used to create the S3 Session object
 | 
			
		||||
    assert "aws_session_token" in session_args
 | 
			
		||||
@@ -273,10 +284,10 @@ def test_gather_s3_information(monkeypatch, capfd):
 | 
			
		||||
def test_remove_s3_url(monkeypatch, capfd):
 | 
			
		||||
    fake_s3_url = "s3://my-bucket/subdirectory/mirror"
 | 
			
		||||
 | 
			
		||||
    def mock_create_s3_session(url, connection={}):
 | 
			
		||||
    def get_s3_session(url, method="fetch"):
 | 
			
		||||
        return MockS3Client()
 | 
			
		||||
 | 
			
		||||
    monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session)
 | 
			
		||||
    monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session)
 | 
			
		||||
 | 
			
		||||
    current_debug_level = tty.debug_level()
 | 
			
		||||
    tty.set_debug(1)
 | 
			
		||||
@@ -292,10 +303,10 @@ def mock_create_s3_session(url, connection={}):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_s3_url_exists(monkeypatch, capfd):
 | 
			
		||||
    def mock_create_s3_session(url, connection={}):
 | 
			
		||||
    def get_s3_session(url, method="fetch"):
 | 
			
		||||
        return MockS3Client()
 | 
			
		||||
 | 
			
		||||
    monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session)
 | 
			
		||||
    monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session)
 | 
			
		||||
 | 
			
		||||
    fake_s3_url_exists = "s3://my-bucket/subdirectory/my-file"
 | 
			
		||||
    assert spack.util.web.url_exists(fake_s3_url_exists)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,27 +4,75 @@
 | 
			
		||||
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
 | 
			
		||||
import os
 | 
			
		||||
import urllib.parse
 | 
			
		||||
from typing import Any, Dict, Tuple
 | 
			
		||||
 | 
			
		||||
import spack
 | 
			
		||||
import spack.config
 | 
			
		||||
import spack.util.url as url_util
 | 
			
		||||
 | 
			
		||||
#: Map (mirror name, method) tuples to s3 client instances.
 | 
			
		||||
s3_client_cache: Dict[Tuple[str, str], Any] = dict()
 | 
			
		||||
 | 
			
		||||
def get_mirror_connection(url, url_type="push"):
    """Look up connection settings for ``url`` among the configured mirrors.

    The formatted URL is compared against the push URL of every mirror in
    ``spack.mirror.MirrorCollection()``. The longest push URL that is a prefix
    of the formatted URL wins, and the ``url_type`` entry of that mirror's
    ``to_dict()`` is returned. When no push URL matches, an empty dict is
    returned.
    """
    target = url_util.format(url)

    # Index every configured mirror by its push URL.
    by_push_url = {m.push_url: m for m in spack.mirror.MirrorCollection().values()}

    # Try the longest (most specific) push URLs first.
    for candidate in sorted(by_push_url.keys(), key=len, reverse=True):
        if target.startswith(candidate):
            return by_push_url[candidate].to_dict()[url_type]

    return {}
 | 
			
		||||
 | 
			
		||||
def get_s3_session(url, method="fetch"):
    """Return an S3 client for ``url``, reusing a cached client when possible.

    The configured mirrors whose fetch URL (``method == "fetch"``) or push URL
    (any other ``method``) is a prefix of the formatted ``url`` are collected,
    and the one with the longest such URL is selected. Clients are cached in
    ``s3_client_cache`` under ``(mirror name, method)``; when no mirror
    matches, the name part of the key is ``None``.

    If the boto3 session reports no credentials, the client is configured for
    unsigned (anonymous) access. The returned client carries botocore's
    ``ClientError`` as its ``ClientError`` attribute.
    """
    # boto and friends are imported as late as possible so that boto is only
    # required when the user actually accesses S3 mirrors.
    from boto3 import Session
    from botocore import UNSIGNED
    from botocore.client import Config
    from botocore.exceptions import ClientError

    # Imported here because of a circular dependency.
    from spack.mirror import MirrorCollection

    formatted_url = url_util.format(url_util.parse(url))

    def url_of(mirror):
        if method == "fetch":
            return mirror.fetch_url
        return mirror.push_url

    # Every configured (name, mirror) pair whose URL prefixes the request.
    candidates = [
        entry for entry in MirrorCollection().items() if formatted_url.startswith(url_of(entry[1]))
    ]

    if candidates:
        # With several matches, prefer the longest URL: it is the most specific
        # one, and a sub-bucket may have its own credentials.
        name, mirror = max(candidates, key=lambda entry: len(url_of(entry[1])))
    else:
        name, mirror = None, {}

    cache_key = (name, method)

    # Reuse a previously created client for this mirror/method, if any.
    try:
        return s3_client_cache[cache_key]
    except KeyError:
        pass

    session_kwargs, client_kwargs = get_mirror_s3_connection_info(mirror, method)
    session = Session(**session_kwargs)

    # No access credentials available: fall back to anonymous access.
    if not session.get_credentials():
        client_kwargs["config"] = Config(signature_version=UNSIGNED)

    client = session.client("s3", **client_kwargs)
    client.ClientError = ClientError

    s3_client_cache[cache_key] = client
    return client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_s3_endpoint_url(endpoint_url):
 | 
			
		||||
@@ -34,53 +82,37 @@ def _parse_s3_endpoint_url(endpoint_url):
 | 
			
		||||
    return endpoint_url
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_mirror_s3_connection_info(connection):
 | 
			
		||||
def get_mirror_s3_connection_info(mirror, method):
 | 
			
		||||
    """Create s3 config for session/client from a Mirror instance (or just set defaults
 | 
			
		||||
    when no mirror is given.)"""
 | 
			
		||||
    from spack.mirror import Mirror
 | 
			
		||||
 | 
			
		||||
    s3_connection = {}
 | 
			
		||||
 | 
			
		||||
    s3_connection_is_dict = connection and isinstance(connection, dict)
 | 
			
		||||
    if s3_connection_is_dict:
 | 
			
		||||
        if connection.get("access_token"):
 | 
			
		||||
            s3_connection["aws_session_token"] = connection["access_token"]
 | 
			
		||||
        if connection.get("access_pair"):
 | 
			
		||||
            s3_connection["aws_access_key_id"] = connection["access_pair"][0]
 | 
			
		||||
            s3_connection["aws_secret_access_key"] = connection["access_pair"][1]
 | 
			
		||||
        if connection.get("profile"):
 | 
			
		||||
            s3_connection["profile_name"] = connection["profile"]
 | 
			
		||||
 | 
			
		||||
    s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")}
 | 
			
		||||
 | 
			
		||||
    # access token
 | 
			
		||||
    if isinstance(mirror, Mirror):
 | 
			
		||||
        access_token = mirror.get_access_token(method)
 | 
			
		||||
        if access_token:
 | 
			
		||||
            s3_connection["aws_session_token"] = access_token
 | 
			
		||||
 | 
			
		||||
        # access pair
 | 
			
		||||
        access_pair = mirror.get_access_pair(method)
 | 
			
		||||
        if access_pair and access_pair[0] and access_pair[1]:
 | 
			
		||||
            s3_connection["aws_access_key_id"] = access_pair[0]
 | 
			
		||||
            s3_connection["aws_secret_access_key"] = access_pair[1]
 | 
			
		||||
 | 
			
		||||
        # profile
 | 
			
		||||
        profile = mirror.get_profile(method)
 | 
			
		||||
        if profile:
 | 
			
		||||
            s3_connection["profile_name"] = profile
 | 
			
		||||
 | 
			
		||||
        # endpoint url
 | 
			
		||||
        endpoint_url = mirror.get_endpoint_url(method) or os.environ.get("S3_ENDPOINT_URL")
 | 
			
		||||
    else:
 | 
			
		||||
        endpoint_url = os.environ.get("S3_ENDPOINT_URL")
 | 
			
		||||
 | 
			
		||||
    if endpoint_url:
 | 
			
		||||
        s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url)
 | 
			
		||||
    elif s3_connection_is_dict and connection.get("endpoint_url"):
 | 
			
		||||
        s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(connection["endpoint_url"])
 | 
			
		||||
 | 
			
		||||
    return (s3_connection, s3_client_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_s3_session(url, connection=None):
    """Create a new S3 client for ``url``.

    Args:
        url: URL (or anything ``url_util.parse`` accepts) whose scheme must
            be ``s3``.
        connection: optional connection settings forwarded to
            ``get_mirror_s3_connection_info``; defaults to an empty dict.

    Returns:
        A boto3 S3 client with botocore's ``ClientError`` attached as its
        ``ClientError`` attribute.

    Raises:
        ValueError: if the parsed URL's scheme is not ``s3``.
    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    if connection is None:
        connection = {}

    url = url_util.parse(url)
    if url.scheme != "s3":
        raise ValueError(
            "Can not create S3 session from URL with scheme: {SCHEME}".format(SCHEME=url.scheme)
        )

    # NOTE(opadron): import boto and friends as late as possible.  We don't
    # want to require boto as a dependency unless the user actually wants to
    # access S3 mirrors.
    from boto3 import Session  # type: ignore[import]
    from botocore.exceptions import ClientError  # type: ignore[import]

    s3_connection, s3_client_args = get_mirror_s3_connection_info(connection)

    session = Session(**s3_connection)
    # if no access credentials provided above, then access anonymously
    if not session.get_credentials():
        from botocore import UNSIGNED  # type: ignore[import]
        from botocore.client import Config  # type: ignore[import]

        s3_client_args["config"] = Config(signature_version=UNSIGNED)

    client = session.client("s3", **s3_client_args)
    client.ClientError = ClientError
    return client
 | 
			
		||||
 
 | 
			
		||||
@@ -175,9 +175,7 @@ def push_to_url(local_file_path, remote_path, keep_original=True, extra_args=Non
 | 
			
		||||
        while remote_path.startswith("/"):
 | 
			
		||||
            remote_path = remote_path[1:]
 | 
			
		||||
 | 
			
		||||
        s3 = s3_util.create_s3_session(
 | 
			
		||||
            remote_url, connection=s3_util.get_mirror_connection(remote_url)
 | 
			
		||||
        )
 | 
			
		||||
        s3 = s3_util.get_s3_session(remote_url, method="push")
 | 
			
		||||
        s3.upload_file(local_file_path, remote_url.netloc, remote_path, ExtraArgs=extra_args)
 | 
			
		||||
 | 
			
		||||
        if not keep_original:
 | 
			
		||||
@@ -377,9 +375,7 @@ def url_exists(url, curl=None):
 | 
			
		||||
    # Check if Amazon Simple Storage Service (S3) .. urllib-based fetch
 | 
			
		||||
    if url_result.scheme == "s3":
 | 
			
		||||
        # Check for URL-specific connection information
 | 
			
		||||
        s3 = s3_util.create_s3_session(
 | 
			
		||||
            url_result, connection=s3_util.get_mirror_connection(url_result)
 | 
			
		||||
        )  # noqa: E501
 | 
			
		||||
        s3 = s3_util.get_s3_session(url_result, method="fetch")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            s3.get_object(Bucket=url_result.netloc, Key=url_result.path.lstrip("/"))
 | 
			
		||||
@@ -441,7 +437,7 @@ def remove_url(url, recursive=False):
 | 
			
		||||
 | 
			
		||||
    if url.scheme == "s3":
 | 
			
		||||
        # Try to find a mirror for potential connection information
 | 
			
		||||
        s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url))
 | 
			
		||||
        s3 = s3_util.get_s3_session(url, method="push")
 | 
			
		||||
        bucket = url.netloc
 | 
			
		||||
        if recursive:
 | 
			
		||||
            # Because list_objects_v2 can only return up to 1000 items
 | 
			
		||||
@@ -551,7 +547,7 @@ def list_url(url, recursive=False):
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    if url.scheme == "s3":
 | 
			
		||||
        s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url))
 | 
			
		||||
        s3 = s3_util.get_s3_session(url, method="fetch")
 | 
			
		||||
        if recursive:
 | 
			
		||||
            return list(_iter_s3_prefix(s3, url))
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user