fix: Handle file URI conversion for LiteLLM based on provider and model

This change updates how `file_data.file_uri` parts are converted to LiteLLM content. For providers like OpenAI and Azure, only URIs resembling OpenAI file IDs ("file-...") are passed as file objects. Other URIs are converted to a text placeholder. Anthropic, and Vertex AI models that are not Gemini, always receive the text placeholder instead of a file block.

Close #4038

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 855277306
This commit is contained in:
George Weale
2026-01-12 10:10:17 -08:00
committed by Copybara-Service
parent 94d48fce32
commit 43b484ff66
2 changed files with 195 additions and 3 deletions
+75 -3
View File
@@ -181,6 +181,45 @@ def _infer_mime_type_from_uri(uri: str) -> Optional[str]:
return None
def _looks_like_openai_file_id(file_uri: str) -> bool:
"""Returns True when file_uri resembles an OpenAI/Azure file id."""
return file_uri.startswith("file-")
def _redact_file_uri_for_log(
file_uri: str, *, display_name: str | None = None
) -> str:
"""Returns a privacy-preserving identifier for logs."""
if display_name:
return display_name
if _looks_like_openai_file_id(file_uri):
return "file-<redacted>"
try:
parsed = urlparse(file_uri)
except ValueError:
return "<unparseable>"
if not parsed.scheme:
return "<unknown>"
segments = [segment for segment in parsed.path.split("/") if segment]
tail = segments[-1] if segments else ""
if tail:
return f"{parsed.scheme}://<redacted>/{tail}"
return f"{parsed.scheme}://<redacted>"
def _requires_file_uri_fallback(
    provider: str, model: str, file_uri: str
) -> bool:
  """Decides whether `file_uri` must be replaced by a fallback representation.

  Args:
    provider: The LLM provider name (e.g., "openai", "anthropic",
      "vertex_ai").
    model: The LiteLLM model string; only consulted for the "vertex_ai"
      provider.
    file_uri: The URI or file id carried by the file_data part.

  Returns:
    True when the URI should not be sent as a file content block: for
    providers in `_FILE_ID_REQUIRED_PROVIDERS` unless the URI looks like an
    OpenAI file id, always for "anthropic", and for "vertex_ai" when the
    model is not a Gemini model. False in every other case.
  """
  if provider in _FILE_ID_REQUIRED_PROVIDERS:
    # These providers only accept uploaded-file ids, never arbitrary URIs.
    is_file_id = _looks_like_openai_file_id(file_uri)
    return not is_file_id
  if provider == "anthropic":
    return True
  return provider == "vertex_ai" and not _is_litellm_gemini_model(model)
def _decode_inline_text_data(raw_bytes: bytes) -> str:
"""Decodes inline file bytes that represent textual content."""
try:
@@ -447,6 +486,7 @@ async def _content_to_message_param(
content: types.Content,
*,
provider: str = "",
model: str = "",
) -> Union[Message, list[Message]]:
"""Converts a types.Content to a litellm Message or list of Messages.
@@ -456,6 +496,7 @@ async def _content_to_message_param(
Args:
content: The content to convert.
provider: The LLM provider name (e.g., "openai", "azure").
model: The LiteLLM model string, used for provider-specific behavior.
Returns:
A litellm Message, a list of litellm Messages.
@@ -499,7 +540,9 @@ async def _content_to_message_param(
if role == "user":
user_parts = [part for part in content.parts if not part.thought]
message_content = await _get_content(user_parts, provider=provider) or None
message_content = (
await _get_content(user_parts, provider=provider, model=model) or None
)
return ChatCompletionUserMessage(role="user", content=message_content)
else: # assistant/model
tool_calls = []
@@ -523,7 +566,7 @@ async def _content_to_message_param(
content_parts.append(part)
final_content = (
await _get_content(content_parts, provider=provider)
await _get_content(content_parts, provider=provider, model=model)
if content_parts
else None
)
@@ -620,6 +663,7 @@ async def _get_content(
parts: Iterable[types.Part],
*,
provider: str = "",
model: str = "",
) -> OpenAIMessageContent:
"""Converts a list of parts to litellm content.
@@ -629,6 +673,8 @@ async def _get_content(
Args:
parts: The parts to convert.
provider: The LLM provider name (e.g., "openai", "azure").
model: The LiteLLM model string (e.g., "openai/gpt-4o",
"vertex_ai/gemini-2.5-flash").
Returns:
The litellm content.
@@ -709,6 +755,32 @@ async def _get_content(
f"{part.inline_data.mime_type}."
)
elif part.file_data and part.file_data.file_uri:
if (
provider in _FILE_ID_REQUIRED_PROVIDERS
and _looks_like_openai_file_id(part.file_data.file_uri)
):
content_objects.append({
"type": "file",
"file": {"file_id": part.file_data.file_uri},
})
continue
if _requires_file_uri_fallback(provider, model, part.file_data.file_uri):
logger.debug(
"File URI %s not supported for provider %s, using text fallback",
_redact_file_uri_for_log(
part.file_data.file_uri,
display_name=part.file_data.display_name,
),
provider,
)
identifier = part.file_data.display_name or part.file_data.file_uri
content_objects.append({
"type": "text",
"text": f'[File reference: "{identifier}"]',
})
continue
file_object: ChatCompletionFileUrlObject = {
"file_id": part.file_data.file_uri,
}
@@ -1363,7 +1435,7 @@ async def _get_completion_inputs(
messages: List[Message] = []
for content in llm_request.contents or []:
message_param_or_list = await _content_to_message_param(
content, provider=provider
content, provider=provider, model=model
)
if isinstance(message_param_or_list, list):
messages.extend(message_param_or_list)
+120
View File
@@ -2304,6 +2304,126 @@ async def test_get_content_file_uri(file_uri, mime_type):
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "provider,model",
    [
        ("openai", "openai/gpt-4o"),
        ("azure", "azure/gpt-4"),
    ],
)
async def test_get_content_file_uri_file_id_required_falls_back_to_text(
    provider, model
):
  """A non-file-id URI becomes a text placeholder for OpenAI and Azure."""
  gcs_file = types.FileData(
      file_uri="gs://bucket/path/to/document.pdf",
      mime_type="application/pdf",
      display_name="document.pdf",
  )
  expected = [{"type": "text", "text": '[File reference: "document.pdf"]'}]

  actual = await _get_content(
      [types.Part(file_data=gcs_file)], provider=provider, model=model
  )

  assert actual == expected
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "provider,model",
    [
        ("openai", "openai/gpt-4o"),
        ("azure", "azure/gpt-4"),
    ],
)
async def test_get_content_file_uri_file_id_required_preserves_file_id(
    provider, model
):
  """A "file-..." id is passed through as a file block for OpenAI and Azure."""
  uploaded_file = types.FileData(
      file_uri="file-abc123",
      mime_type="application/pdf",
  )
  expected = [{"type": "file", "file": {"file_id": "file-abc123"}}]

  actual = await _get_content(
      [types.Part(file_data=uploaded_file)], provider=provider, model=model
  )

  assert actual == expected
@pytest.mark.asyncio
async def test_get_content_file_uri_anthropic_falls_back_to_text():
  """A GCS URI sent to Anthropic is replaced by a text placeholder."""
  gcs_file = types.FileData(
      file_uri="gs://bucket/path/to/document.pdf",
      mime_type="application/pdf",
      display_name="document.pdf",
  )
  expected = [{"type": "text", "text": '[File reference: "document.pdf"]'}]

  actual = await _get_content(
      [types.Part(file_data=gcs_file)],
      provider="anthropic",
      model="anthropic/claude-3-5",
  )

  assert actual == expected
@pytest.mark.asyncio
async def test_get_content_file_uri_anthropic_openai_file_id_falls_back_to_text():
  """An OpenAI-style file id sent to Anthropic still becomes a placeholder.

  No display name is set, so the placeholder text carries the raw file id.
  """
  openai_style_file = types.FileData(file_uri="file-abc123")
  expected = [{"type": "text", "text": '[File reference: "file-abc123"]'}]

  actual = await _get_content(
      [types.Part(file_data=openai_style_file)],
      provider="anthropic",
      model="anthropic/claude-3-5",
  )

  assert actual == expected
@pytest.mark.asyncio
async def test_get_content_file_uri_vertex_ai_non_gemini_falls_back_to_text():
  """A non-Gemini Vertex AI model gets a text placeholder for a GCS URI."""
  gcs_file = types.FileData(
      file_uri="gs://bucket/path/to/document.pdf",
      mime_type="application/pdf",
      display_name="document.pdf",
  )
  expected = [{"type": "text", "text": '[File reference: "document.pdf"]'}]

  actual = await _get_content(
      [types.Part(file_data=gcs_file)],
      provider="vertex_ai",
      model="vertex_ai/claude-3-5",
  )

  assert actual == expected
@pytest.mark.asyncio
async def test_get_content_file_uri_vertex_ai_gemini_keeps_file_block():
  """A Gemini model on Vertex AI keeps the GCS URI as a file content block."""
  gcs_uri = "gs://bucket/path/to/document.pdf"
  pdf_mime_type = "application/pdf"
  gcs_file = types.FileData(file_uri=gcs_uri, mime_type=pdf_mime_type)
  expected = [{
      "type": "file",
      "file": {
          "file_id": gcs_uri,
          "format": pdf_mime_type,
      },
  }]

  actual = await _get_content(
      [types.Part(file_data=gcs_file)],
      provider="vertex_ai",
      model="vertex_ai/gemini-2.5-flash",
  )

  assert actual == expected
@pytest.mark.asyncio
async def test_get_content_file_uri_infer_mime_type():
"""Test MIME type inference from file_uri extension.