fix zip downloads (#5009)

This commit is contained in:
Robert Brennan 2024-11-14 17:17:36 -05:00 committed by GitHub
parent be92965209
commit f3b35663e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 87 additions and 91 deletions

View File

@ -432,12 +432,13 @@ class EventStreamRuntime(Runtime):
if not self.log_buffer:
raise RuntimeError('Runtime client is not ready.')
send_request(
with send_request(
self.session,
'GET',
f'{self.api_url}/alive',
timeout=5,
)
):
pass
def close(self, rm_all_containers: bool = True):
"""Closes the EventStreamRuntime and associated objects
@ -496,17 +497,17 @@ class EventStreamRuntime(Runtime):
assert action.timeout is not None
try:
response = send_request(
with send_request(
self.session,
'POST',
f'{self.api_url}/execute_action',
json={'action': event_to_dict(action)},
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
)
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
) as response:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
raise RuntimeError(
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
@ -567,14 +568,15 @@ class EventStreamRuntime(Runtime):
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
send_request(
with send_request(
self.session,
'POST',
f'{self.api_url}/upload_file',
files=upload_data,
params=params,
timeout=300,
)
):
pass
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
@ -599,16 +601,16 @@ class EventStreamRuntime(Runtime):
if path is not None:
data['path'] = path
response = send_request(
with send_request(
self.session,
'POST',
f'{self.api_url}/list_files',
json=data,
timeout=10,
)
response_json = response.json()
assert isinstance(response_json, list)
return response_json
) as response:
response_json = response.json()
assert isinstance(response_json, list)
return response_json
except requests.Timeout:
raise TimeoutError('List files operation timed out')
@ -617,19 +619,19 @@ class EventStreamRuntime(Runtime):
self._refresh_logs()
try:
params = {'path': path}
response = send_request(
with send_request(
self.session,
'GET',
f'{self.api_url}/download_files',
params=params,
stream=True,
timeout=30,
)
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)
) as response:
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)
except requests.Timeout:
raise TimeoutError('Copy operation timed out')
@ -658,21 +660,21 @@ class EventStreamRuntime(Runtime):
): # cached value
return self._vscode_url
response = send_request(
with send_request(
self.session,
'GET',
f'{self.api_url}/vscode/connection_token',
timeout=10,
)
response_json = response.json()
assert isinstance(response_json, dict)
if response_json['token'] is None:
return None
self._vscode_url = f'http://localhost:{self._host_port + 1}/?tkn={response_json["token"]}&folder={self.config.workspace_mount_path_in_sandbox}'
self.log(
'debug',
f'VSCode URL: {self._vscode_url}',
)
return self._vscode_url
) as response:
response_json = response.json()
assert isinstance(response_json, dict)
if response_json['token'] is None:
return None
self._vscode_url = f'http://localhost:{self._host_port + 1}/?tkn={response_json["token"]}&folder={self.config.workspace_mount_path_in_sandbox}'
self.log(
'debug',
f'VSCode URL: {self._vscode_url}',
)
return self._vscode_url
else:
return None

View File

@ -141,29 +141,29 @@ class RemoteRuntime(Runtime):
def _check_existing_runtime(self) -> bool:
try:
response = self._send_request(
with self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}',
is_retry=False,
timeout=5,
)
) as response:
data = response.json()
status = data.get('status')
if status == 'running' or status == 'paused':
self._parse_runtime_response(response)
except requests.HTTPError as e:
if e.response.status_code == 404:
return False
self.log('debug', f'Error while looking for remote runtime: {e}')
raise
data = response.json()
status = data.get('status')
if status == 'running':
self._parse_runtime_response(response)
return True
elif status == 'stopped':
self.log('debug', 'Found existing remote runtime, but it is stopped')
return False
elif status == 'paused':
self.log('debug', 'Found existing remote runtime, but it is paused')
self._parse_runtime_response(response)
self._resume_runtime()
return True
else:
@ -172,13 +172,13 @@ class RemoteRuntime(Runtime):
def _build_runtime(self):
self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
response = self._send_request(
with self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
is_retry=False,
timeout=10,
)
response_json = response.json()
) as response:
response_json = response.json()
registry_prefix = response_json['registry_prefix']
os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = (
registry_prefix.rstrip('/') + '/runtime'
@ -203,15 +203,17 @@ class RemoteRuntime(Runtime):
force_rebuild=self.config.sandbox.force_rebuild_runtime,
)
response = self._send_request(
with self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
is_retry=False,
params={'image': self.container_image},
timeout=10,
)
if not response.json()['exists']:
raise RuntimeError(f'Container image {self.container_image} does not exist')
) as response:
if not response.json()['exists']:
raise RuntimeError(
f'Container image {self.container_image} does not exist'
)
def _start_runtime(self):
# Prepare the request body for the /start endpoint
@ -240,26 +242,27 @@ class RemoteRuntime(Runtime):
}
# Start the sandbox using the /start endpoint
response = self._send_request(
with self._send_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/start',
is_retry=False,
json=start_request,
)
self._parse_runtime_response(response)
) as response:
self._parse_runtime_response(response)
self.log(
'debug',
f'Runtime started. URL: {self.runtime_url}',
)
def _resume_runtime(self):
self._send_request(
with self._send_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/resume',
is_retry=False,
json={'runtime_id': self.runtime_id},
timeout=30,
)
):
pass
self.log('debug', 'Runtime resumed.')
def _parse_runtime_response(self, response: requests.Response):
@ -279,12 +282,12 @@ class RemoteRuntime(Runtime):
): # cached value
return self._vscode_url
response = self._send_request(
with self._send_request(
'GET',
f'{self.runtime_url}/vscode/connection_token',
timeout=10,
)
response_json = response.json()
) as response:
response_json = response.json()
assert isinstance(response_json, dict)
if response_json['token'] is None:
return None
@ -316,11 +319,11 @@ class RemoteRuntime(Runtime):
def _wait_until_alive_impl(self):
self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
runtime_info_response = self._send_request(
with self._send_request(
'GET',
f'{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}',
)
runtime_data = runtime_info_response.json()
) as runtime_info_response:
runtime_data = runtime_info_response.json()
assert 'runtime_id' in runtime_data
assert runtime_data['runtime_id'] == self.runtime_id
assert 'pod_status' in runtime_data
@ -332,10 +335,11 @@ class RemoteRuntime(Runtime):
# Retry a period of time to give the cluster time to start the pod
if pod_status == 'Ready':
try:
self._send_request(
with self._send_request(
'GET',
f'{self.runtime_url}/alive',
) # will raise exception if we don't get 200 back.
): # will raise exception if we don't get 200 back.
pass
except requests.HTTPError as e:
self.log(
'warning', f"Runtime /alive failed, but pod says it's ready: {e}"
@ -374,19 +378,13 @@ class RemoteRuntime(Runtime):
return
if self.runtime_id and self.session:
try:
response = self._send_request(
with self._send_request(
'POST',
f'{self.config.sandbox.remote_runtime_api_url}/stop',
is_retry=False,
json={'runtime_id': self.runtime_id},
timeout=timeout,
)
if response.status_code != 200:
self.log(
'error',
f'Failed to stop runtime: {response.text}',
)
else:
):
self.log('debug', 'Runtime stopped.')
except Exception as e:
raise e
@ -415,15 +413,15 @@ class RemoteRuntime(Runtime):
try:
request_body = {'action': event_to_dict(action)}
self.log('debug', f'Request body: {request_body}')
response = self._send_request(
with self._send_request(
'POST',
f'{self.runtime_url}/execute_action',
is_retry=False,
json=request_body,
# wait a few more seconds to get the timeout error from client side
timeout=action.timeout + 5,
)
output = response.json()
) as response:
output = response.json()
obs = observation_from_dict(output)
obs._cause = action.id # type: ignore[attr-defined]
except requests.Timeout:
@ -502,18 +500,18 @@ class RemoteRuntime(Runtime):
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
response = self._send_request(
with self._send_request(
'POST',
f'{self.runtime_url}/upload_file',
is_retry=False,
files=upload_data,
params=params,
timeout=300,
)
self.log(
'debug',
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
)
) as response:
self.log(
'debug',
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
)
finally:
if recursive:
os.unlink(temp_zip_path)
@ -526,30 +524,30 @@ class RemoteRuntime(Runtime):
if path is not None:
data['path'] = path
response = self._send_request(
with self._send_request(
'POST',
f'{self.runtime_url}/list_files',
is_retry=False,
json=data,
timeout=30,
)
response_json = response.json()
) as response:
response_json = response.json()
assert isinstance(response_json, list)
return response_json
def copy_from(self, path: str) -> Path:
"""Zip all files in the sandbox and return as a stream of bytes."""
params = {'path': path}
response = self._send_request(
with self._send_request(
'GET',
f'{self.runtime_url}/download_files',
is_retry=False,
params=params,
stream=True,
timeout=30,
)
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)
) as response:
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)

View File

@ -58,9 +58,5 @@ def send_request(
**kwargs: Any,
) -> requests.Response:
response = session.request(method, url, **kwargs)
try:
response.raise_for_status()
finally:
response.close()
response.raise_for_status()
return response