Improve SensitiveDataFilter and add comprehensive tests (#6755)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
tofarr 2025-02-17 17:23:43 +00:00 committed by GitHub
parent ae31a24c29
commit f4b123f73b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 155 additions and 39 deletions

View File

@ -217,7 +217,21 @@ class RollingLogger:
class SensitiveDataFilter(logging.Filter):
def filter(self, record):
# start with attributes
# Gather sensitive values which should not ever appear in the logs.
sensitive_values = []
for key, value in os.environ.items():
key_upper = key.upper()
if len(value) > 2 and any(
s in key_upper for s in ('SECRET', 'KEY', 'CODE', 'TOKEN')
):
sensitive_values.append(value)
# Replace sensitive values from env!
msg = record.getMessage()
for sensitive_value in sensitive_values:
msg = msg.replace(sensitive_value, '******')
# Replace obvious sensitive values from log itself...
sensitive_patterns = [
'api_key',
'aws_access_key_id',
@ -227,28 +241,22 @@ class SensitiveDataFilter(logging.Filter):
'jwt_secret',
'modal_api_token_id',
'modal_api_token_secret',
'llm_api_key',
'sandbox_env_github_token',
]
# add env var names
env_vars = [attr.upper() for attr in sensitive_patterns]
sensitive_patterns.extend(env_vars)
# and some special cases
sensitive_patterns.append('JWT_SECRET')
sensitive_patterns.append('LLM_API_KEY')
sensitive_patterns.append('GITHUB_TOKEN')
sensitive_patterns.append('SANDBOX_ENV_GITHUB_TOKEN')
# this also formats the message with % args
msg = record.getMessage()
record.args = ()
for attr in sensitive_patterns:
pattern = rf"{attr}='?([\w-]+)'?"
msg = re.sub(pattern, f"{attr}='******'", msg)
# passed with msg
# Update the record
record.msg = msg
record.args = ()
return True

117
tests/unit/test_logger.py Normal file
View File

@ -0,0 +1,117 @@
import logging
from unittest.mock import patch
from openhands.core.logger import SensitiveDataFilter
@patch.dict(
'os.environ',
{
'API_SECRET': 'super-secret-123',
'AUTH_TOKEN': 'auth-token-456',
'NORMAL_VAR': 'normal-value',
},
clear=True,
)
def test_sensitive_data_filter_basic():
# Create a filter instance
filter = SensitiveDataFilter()
# Create a log record with sensitive data
record = logging.LogRecord(
name='test_logger',
level=logging.INFO,
pathname='test.py',
lineno=1,
msg='API Secret: super-secret-123, Token: auth-token-456, Normal: normal-value',
args=(),
exc_info=None,
)
# Apply the filter
filter.filter(record)
# Check that sensitive data is masked but normal data isn't
assert '******' in record.msg
assert 'super-secret-123' not in record.msg
assert 'auth-token-456' not in record.msg
assert 'normal-value' in record.msg
@patch.dict('os.environ', {}, clear=True)
def test_sensitive_data_filter_empty_values():
# Test with empty environment variables
filter = SensitiveDataFilter()
record = logging.LogRecord(
name='test_logger',
level=logging.INFO,
pathname='test.py',
lineno=1,
msg='No sensitive data here',
args=(),
exc_info=None,
)
# Apply the filter
filter.filter(record)
# Message should remain unchanged
assert record.msg == 'No sensitive data here'
@patch.dict('os.environ', {'API_KEY': 'secret-key-789'}, clear=True)
def test_sensitive_data_filter_multiple_occurrences():
# Test with multiple occurrences of the same sensitive data
filter = SensitiveDataFilter()
# Create a message with multiple occurrences of the same sensitive data
record = logging.LogRecord(
name='test_logger',
level=logging.INFO,
pathname='test.py',
lineno=1,
msg='Key1: secret-key-789, Key2: secret-key-789',
args=(),
exc_info=None,
)
# Apply the filter
filter.filter(record)
# Check that all occurrences are masked
assert record.msg.count('******') == 2
assert 'secret-key-789' not in record.msg
@patch.dict(
'os.environ',
{
'secret_KEY': 'secret-value-1',
'API_secret': 'secret-value-2',
'TOKEN_code': 'secret-value-3',
},
clear=True,
)
def test_sensitive_data_filter_case_sensitivity():
# Test with different case variations in environment variable names
filter = SensitiveDataFilter()
record = logging.LogRecord(
name='test_logger',
level=logging.INFO,
pathname='test.py',
lineno=1,
msg='Values: secret-value-1, secret-value-2, secret-value-3',
args=(),
exc_info=None,
)
# Apply the filter
filter.filter(record)
# Check that all sensitive values are masked regardless of case
assert 'secret-value-1' not in record.msg
assert 'secret-value-2' not in record.msg
assert 'secret-value-3' not in record.msg
assert record.msg.count('******') == 3

View File

@ -1,5 +1,6 @@
import logging
from io import StringIO
from unittest.mock import patch
import pytest
@ -26,7 +27,6 @@ def test_openai_api_key_masking(test_handler):
message = f"OpenAI API key: api_key='{api_key}'and there's some stuff here"
logger.info(message)
log_output = stream.getvalue()
assert "api_key='******'" in log_output
assert api_key not in log_output
@ -36,7 +36,6 @@ def test_azure_api_key_masking(test_handler):
message = f"Azure API key: api_key='{api_key}' and chatty chat with ' and \" and '"
logger.info(message)
log_output = stream.getvalue()
assert "api_key='******'" in log_output
assert api_key not in log_output
@ -46,7 +45,6 @@ def test_google_vertex_api_key_masking(test_handler):
message = f"Google Vertex API key: api_key='{api_key}' or not"
logger.info(message)
log_output = stream.getvalue()
assert "api_key='******'" in log_output
assert api_key not in log_output
@ -56,7 +54,6 @@ def test_anthropic_api_key_masking(test_handler):
message = f"Anthropic API key: api_key='{api_key}' and there's some 'stuff' here"
logger.info(message)
log_output = stream.getvalue()
assert "api_key='******'" in log_output
assert api_key not in log_output
@ -69,9 +66,6 @@ def test_llm_config_attributes_masking(test_handler):
)
logger.info(f'LLM Config: {llm_config}')
log_output = stream.getvalue()
assert "api_key='******'" in log_output
assert "aws_access_key_id='******'" in log_output
assert "aws_secret_access_key='******'" in log_output
assert 'sk-abc123' not in log_output
assert 'AKIAIOSFODNN7EXAMPLE' not in log_output
assert 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY' not in log_output
@ -82,7 +76,6 @@ def test_app_config_attributes_masking(test_handler):
app_config = AppConfig(e2b_api_key='e2b-xyz789')
logger.info(f'App Config: {app_config}')
log_output = stream.getvalue()
assert "e2b_api_key='******'" in log_output
assert 'github_token' not in log_output
assert 'e2b-xyz789' not in log_output
assert 'ghp_abcdefghijklmnopqrstuvwxyz' not in log_output
@ -90,7 +83,7 @@ def test_app_config_attributes_masking(test_handler):
def test_sensitive_env_vars_masking(test_handler):
logger, stream = test_handler
sensitive_data = {
environ = {
'API_KEY': 'API_KEY_VALUE',
'AWS_ACCESS_KEY_ID': 'AWS_ACCESS_KEY_ID_VALUE',
'AWS_SECRET_ACCESS_KEY': 'AWS_SECRET_ACCESS_KEY_VALUE',
@ -99,31 +92,29 @@ def test_sensitive_env_vars_masking(test_handler):
'JWT_SECRET': 'JWT_SECRET_VALUE',
}
log_message = ' '.join(
f"{attr}='{value}'" for attr, value in sensitive_data.items()
)
logger.info(log_message)
with patch.dict('openhands.core.logger.os.environ', environ, clear=True):
log_message = ' '.join(f"{attr}='{value}'" for attr, value in environ.items())
logger.info(log_message)
log_output = stream.getvalue()
for attr, value in sensitive_data.items():
assert f"{attr}='******'" in log_output
assert value not in log_output
log_output = stream.getvalue()
for _, value in environ.items():
assert value not in log_output
def test_special_cases_masking(test_handler):
logger, stream = test_handler
sensitive_data = {
environ = {
'LLM_API_KEY': 'LLM_API_KEY_VALUE',
'SANDBOX_ENV_GITHUB_TOKEN': 'SANDBOX_ENV_GITHUB_TOKEN_VALUE',
}
log_message = ' '.join(
f"{attr}={value} with no single quotes' and something"
for attr, value in sensitive_data.items()
)
logger.info(log_message)
with patch.dict('openhands.core.logger.os.environ', environ, clear=True):
log_message = ' '.join(
f"{attr}={value} with no single quotes' and something"
for attr, value in environ.items()
)
logger.info(log_message)
log_output = stream.getvalue()
for attr, value in sensitive_data.items():
assert f"{attr}='******'" in log_output
assert value not in log_output
log_output = stream.getvalue()
for attr, value in environ.items():
assert value not in log_output