s3: cache client instance (#34372)
Parent: d29cb87ecc
Commit: 7e054cb7fc
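In short: the old create_s3_session(url, connection=...) helper built a new boto3 session and client on every call, and each call site had to look up connection info via get_mirror_connection. The new get_s3_session(url, method=...) resolves the matching mirror itself and memoizes one client per (mirror name, method) pair. A minimal sketch of that memoization pattern; the cached_client helper below is illustrative and not part of the diff:

    from typing import Any, Callable, Dict, Tuple

    # Cache keyed the same way as s3_client_cache in the diff below: (mirror name, method).
    _client_cache: Dict[Tuple[str, str], Any] = {}

    def cached_client(mirror_name: str, method: str, make_client: Callable[[], Any]) -> Any:
        """Return the memoized client for (mirror_name, method), creating it on first use."""
        key = (mirror_name, method)
        if key not in _client_cache:
            _client_cache[key] = make_client()  # expensive step: boto3 Session + client
        return _client_cache[key]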
@@ -44,7 +44,7 @@ def __getattr__(self, key):
 def _s3_open(url):
     parsed = url_util.parse(url)
-    s3 = s3_util.create_s3_session(parsed, connection=s3_util.get_mirror_connection(parsed))
+    s3 = s3_util.get_s3_session(url, method="fetch")

     bucket = parsed.netloc
     key = parsed.path

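For context, a hedged sketch of how this read path looks end to end with the new helper; the bucket and key are made up, the leading-slash handling follows the url_exists hunk further down, and boto3 plus either a matching mirror or anonymous access are assumed:

    import spack.util.s3 as s3_util
    import spack.util.url as url_util

    url = "s3://my-bucket/subdirectory/my-file"  # hypothetical object
    parsed = url_util.parse(url)

    # One cached client per (matching mirror, "fetch"); unmatched URLs fall back to
    # an anonymous client, as shown in the get_s3_session hunk below.
    s3 = s3_util.get_s3_session(url, method="fetch")
    obj = s3.get_object(Bucket=parsed.netloc, Key=parsed.path.lstrip("/"))
    data = obj["Body"].read()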
@@ -12,6 +12,7 @@
 import llnl.util.tty as tty

 import spack.config
+import spack.mirror
 import spack.paths
 import spack.util.s3
 import spack.util.web
@@ -246,14 +247,24 @@ def get_object(self, Bucket=None, Key=None):


 def test_gather_s3_information(monkeypatch, capfd):
-    mock_connection_data = {
-        "access_token": "AAAAAAA",
-        "profile": "SPacKDeV",
-        "access_pair": ("SPA", "CK"),
-        "endpoint_url": "https://127.0.0.1:8888",
-    }
+    mirror = spack.mirror.Mirror.from_dict(
+        {
+            "fetch": {
+                "access_token": "AAAAAAA",
+                "profile": "SPacKDeV",
+                "access_pair": ("SPA", "CK"),
+                "endpoint_url": "https://127.0.0.1:8888",
+            },
+            "push": {
+                "access_token": "AAAAAAA",
+                "profile": "SPacKDeV",
+                "access_pair": ("SPA", "CK"),
+                "endpoint_url": "https://127.0.0.1:8888",
+            },
+        }
+    )

-    session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mock_connection_data)
+    session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mirror, "push")

     # Session args are used to create the S3 Session object
     assert "aws_session_token" in session_args
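The reworked test feeds a real Mirror with separate fetch and push credential blocks instead of a bare dict. A hedged sketch of what the returned pair should contain for the "push" side, assuming Mirror.get_access_token / get_access_pair / get_profile / get_endpoint_url return the per-method values from this dict (the key names follow the get_mirror_s3_connection_info hunk further down):

    import spack.mirror
    import spack.util.s3

    creds = {
        "access_token": "AAAAAAA",
        "profile": "SPacKDeV",
        "access_pair": ("SPA", "CK"),
        "endpoint_url": "https://127.0.0.1:8888",
    }
    mirror = spack.mirror.Mirror.from_dict({"fetch": creds, "push": creds})

    session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mirror, "push")
    assert session_args.get("aws_session_token") == "AAAAAAA"  # from "access_token"
    assert session_args.get("aws_access_key_id") == "SPA"      # from "access_pair"[0]
    assert session_args.get("aws_secret_access_key") == "CK"   # from "access_pair"[1]
    assert session_args.get("profile_name") == "SPacKDeV"      # from "profile"
    assert "endpoint_url" in client_args                       # from "endpoint_url"
    assert "use_ssl" in client_args                            # from config:verify_ssl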
@@ -273,10 +284,10 @@ def test_gather_s3_information(monkeypatch, capfd):
 def test_remove_s3_url(monkeypatch, capfd):
     fake_s3_url = "s3://my-bucket/subdirectory/mirror"

-    def mock_create_s3_session(url, connection={}):
+    def get_s3_session(url, method="fetch"):
         return MockS3Client()

-    monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session)
+    monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session)

     current_debug_level = tty.debug_level()
     tty.set_debug(1)
@@ -292,10 +303,10 @@ def mock_create_s3_session(url, connection={}):


 def test_s3_url_exists(monkeypatch, capfd):
-    def mock_create_s3_session(url, connection={}):
+    def get_s3_session(url, method="fetch"):
         return MockS3Client()

-    monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session)
+    monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session)

     fake_s3_url_exists = "s3://my-bucket/subdirectory/my-file"
     assert spack.util.web.url_exists(fake_s3_url_exists)

@@ -4,27 +4,75 @@
 # SPDX-License-Identifier: (Apache-2.0 OR MIT)
 import os
 import urllib.parse
+from typing import Any, Dict, Tuple

 import spack
+import spack.config
 import spack.util.url as url_util

+#: Map (mirror name, method) tuples to s3 client instances.
+s3_client_cache: Dict[Tuple[str, str], Any] = dict()
+

-def get_mirror_connection(url, url_type="push"):
-    connection = {}
-    # Try to find a mirror for potential connection information
-    # Check to see if desired file starts with any of the mirror URLs
-    rebuilt_path = url_util.format(url)
-    # Gather dict of push URLS point to the value of the whole mirror
-    mirror_dict = {x.push_url: x for x in spack.mirror.MirrorCollection().values()}
-    # Ensure most specific URLs (longest) are presented first
-    mirror_url_keys = mirror_dict.keys()
-    mirror_url_keys = sorted(mirror_url_keys, key=len, reverse=True)
-    for mURL in mirror_url_keys:
-        # See if desired URL starts with the mirror's push URL
-        if rebuilt_path.startswith(mURL):
-            connection = mirror_dict[mURL].to_dict()[url_type]
-            break
-    return connection
+def get_s3_session(url, method="fetch"):
+    # import boto and friends as late as possible. We don't want to require boto as a
+    # dependency unless the user actually wants to access S3 mirrors.
+    from boto3 import Session
+    from botocore import UNSIGNED
+    from botocore.client import Config
+    from botocore.exceptions import ClientError
+
+    # Circular dependency
+    from spack.mirror import MirrorCollection
+
+    global s3_client_cache
+
+    # Get a (recycled) s3 session for a particular URL
+    url = url_util.parse(url)
+    url_str = url_util.format(url)
+
+    def get_mirror_url(mirror):
+        return mirror.fetch_url if method == "fetch" else mirror.push_url
+
+    # Get all configured mirrors that could match.
+    all_mirrors = MirrorCollection()
+    mirrors = [
+        (name, mirror)
+        for name, mirror in all_mirrors.items()
+        if url_str.startswith(get_mirror_url(mirror))
+    ]
+
+    if not mirrors:
+        name, mirror = None, {}
+    else:
+        # In case we have more than one mirror, we pick the longest matching url.
+        # The heuristic being that it's more specific, and you can have different
+        # credentials for a sub-bucket (if that is a thing).
+        name, mirror = max(
+            mirrors, key=lambda name_and_mirror: len(get_mirror_url(name_and_mirror[1]))
+        )
+
+    key = (name, method)
+
+    # Did we already create a client for this? Then return it.
+    if key in s3_client_cache:
+        return s3_client_cache[key]
+
+    # Otherwise, create it.
+    s3_connection, s3_client_args = get_mirror_s3_connection_info(mirror, method)
+
+    session = Session(**s3_connection)
+    # if no access credentials provided above, then access anonymously
+    if not session.get_credentials():
+        s3_client_args["config"] = Config(signature_version=UNSIGNED)
+
+    client = session.client("s3", **s3_client_args)
+    client.ClientError = ClientError
+
+    # Cache the client.
+    s3_client_cache[key] = client
+    return client


 def _parse_s3_endpoint_url(endpoint_url):
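The practical effect of the cache: repeated requests that resolve to the same mirror and method reuse a single boto3 client instead of rebuilding a session each time. A hedged usage sketch, assuming boto3 is installed and an S3 mirror is configured whose URLs prefix these made-up paths:

    import spack.util.s3 as s3_util

    a = s3_util.get_s3_session("s3://my-bucket/prefix/spec.json", method="fetch")
    b = s3_util.get_s3_session("s3://my-bucket/prefix/other.json", method="fetch")
    assert a is b  # same mirror, same method -> one cache entry, client reused

    c = s3_util.get_s3_session("s3://my-bucket/prefix/spec.json", method="push")
    assert c is not a  # the key is (mirror name, method), so push gets its own client

    # URLs matching no configured mirror still work: name is None, credentials stay
    # empty, and the client is created for anonymous (UNSIGNED) access.
    d = s3_util.get_s3_session("s3://some-public-bucket/some-file", method="fetch")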
@@ -34,53 +82,37 @@ def _parse_s3_endpoint_url(endpoint_url):
     return endpoint_url


-def get_mirror_s3_connection_info(connection):
+def get_mirror_s3_connection_info(mirror, method):
+    """Create s3 config for session/client from a Mirror instance (or just set defaults
+    when no mirror is given.)"""
+    from spack.mirror import Mirror
+
     s3_connection = {}

-    s3_connection_is_dict = connection and isinstance(connection, dict)
-    if s3_connection_is_dict:
-        if connection.get("access_token"):
-            s3_connection["aws_session_token"] = connection["access_token"]
-        if connection.get("access_pair"):
-            s3_connection["aws_access_key_id"] = connection["access_pair"][0]
-            s3_connection["aws_secret_access_key"] = connection["access_pair"][1]
-        if connection.get("profile"):
-            s3_connection["profile_name"] = connection["profile"]
-
     s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")}

-    endpoint_url = os.environ.get("S3_ENDPOINT_URL")
+    # access token
+    if isinstance(mirror, Mirror):
+        access_token = mirror.get_access_token(method)
+        if access_token:
+            s3_connection["aws_session_token"] = access_token
+
+        # access pair
+        access_pair = mirror.get_access_pair(method)
+        if access_pair and access_pair[0] and access_pair[1]:
+            s3_connection["aws_access_key_id"] = access_pair[0]
+            s3_connection["aws_secret_access_key"] = access_pair[1]
+
+        # profile
+        profile = mirror.get_profile(method)
+        if profile:
+            s3_connection["profile_name"] = profile
+
+        # endpoint url
+        endpoint_url = mirror.get_endpoint_url(method) or os.environ.get("S3_ENDPOINT_URL")
+    else:
+        endpoint_url = os.environ.get("S3_ENDPOINT_URL")
+
     if endpoint_url:
         s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url)
-    elif s3_connection_is_dict and connection.get("endpoint_url"):
-        s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(connection["endpoint_url"])

     return (s3_connection, s3_client_args)
-
-
-def create_s3_session(url, connection={}):
-    url = url_util.parse(url)
-    if url.scheme != "s3":
-        raise ValueError(
-            "Can not create S3 session from URL with scheme: {SCHEME}".format(SCHEME=url.scheme)
-        )
-
-    # NOTE(opadron): import boto and friends as late as possible. We don't
-    # want to require boto as a dependency unless the user actually wants to
-    # access S3 mirrors.
-    from boto3 import Session  # type: ignore[import]
-    from botocore.exceptions import ClientError  # type: ignore[import]
-
-    s3_connection, s3_client_args = get_mirror_s3_connection_info(connection)
-
-    session = Session(**s3_connection)
-    # if no access credentials provided above, then access anonymously
-    if not session.get_credentials():
-        from botocore import UNSIGNED  # type: ignore[import]
-        from botocore.client import Config  # type: ignore[import]
-
-        s3_client_args["config"] = Config(signature_version=UNSIGNED)
-
-    client = session.client("s3", **s3_client_args)
-    client.ClientError = ClientError
-    return client

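When no mirror matches, get_s3_session passes an empty dict, so only the defaults survive: no credentials (hence the anonymous UNSIGNED fallback above) and an endpoint that can still be forced through S3_ENDPOINT_URL. A hedged sketch of that fallback path, with a made-up local endpoint:

    import os
    import spack.util.s3 as s3_util

    os.environ["S3_ENDPOINT_URL"] = "https://127.0.0.1:9000"  # hypothetical local endpoint

    session_args, client_args = s3_util.get_mirror_s3_connection_info({}, "fetch")
    assert session_args == {}             # no credentials -> anonymous access later on
    assert "use_ssl" in client_args       # always set from config:verify_ssl
    assert "endpoint_url" in client_args  # picked up from the environment variable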
@@ -175,9 +175,7 @@ def push_to_url(local_file_path, remote_path, keep_original=True, extra_args=Non
     while remote_path.startswith("/"):
         remote_path = remote_path[1:]

-    s3 = s3_util.create_s3_session(
-        remote_url, connection=s3_util.get_mirror_connection(remote_url)
-    )
+    s3 = s3_util.get_s3_session(remote_url, method="push")
     s3.upload_file(local_file_path, remote_url.netloc, remote_path, ExtraArgs=extra_args)

     if not keep_original:
@@ -377,9 +375,7 @@ def url_exists(url, curl=None):
     # Check if Amazon Simple Storage Service (S3) .. urllib-based fetch
     if url_result.scheme == "s3":
         # Check for URL-specific connection information
-        s3 = s3_util.create_s3_session(
-            url_result, connection=s3_util.get_mirror_connection(url_result)
-        )  # noqa: E501
+        s3 = s3_util.get_s3_session(url_result, method="fetch")

         try:
             s3.get_object(Bucket=url_result.netloc, Key=url_result.path.lstrip("/"))
@@ -441,7 +437,7 @@ def remove_url(url, recursive=False):

     if url.scheme == "s3":
         # Try to find a mirror for potential connection information
-        s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url))
+        s3 = s3_util.get_s3_session(url, method="push")
         bucket = url.netloc
         if recursive:
             # Because list_objects_v2 can only return up to 1000 items
@@ -551,7 +547,7 @@ def list_url(url, recursive=False):
     ]

     if url.scheme == "s3":
-        s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url))
+        s3 = s3_util.get_s3_session(url, method="fetch")
        if recursive:
             return list(_iter_s3_prefix(s3, url))
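Taken together, the web-layer call sites now pick the S3 method that matches the operation, so a mirror can carry different fetch and push credentials while every call shares the cached clients. A hedged end-to-end sketch with a made-up mirror URL:

    import spack.util.web as web

    url = "s3://my-bucket/subdirectory/mirror"

    if web.url_exists(url):                          # read path  -> get_s3_session(..., method="fetch")
        listing = web.list_url(url, recursive=True)  # also "fetch"; client comes from the cache
        web.remove_url(url, recursive=True)          # write path -> method="push", its own cache entry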