add input url and http_proxy (#475)
hexapode authored Nov 12, 2024
1 parent 89348aa commit eeabf48
Showing 3 changed files with 94 additions and 16 deletions.
45 changes: 37 additions & 8 deletions llama_parse/base.py
@@ -1,6 +1,6 @@
import os
import asyncio
from io import TextIOWrapper
from urllib.parse import urlparse

import httpx
import mimetypes
@@ -11,7 +11,6 @@
from io import BufferedIOBase

from fsspec import AbstractFileSystem
from fsspec.spec import AbstractBufferedFile
from llama_index.core.async_utils import asyncio_run, run_jobs
from llama_index.core.bridge.pydantic import Field, field_validator
from llama_index.core.constants import DEFAULT_BASE_URL
@@ -190,6 +189,10 @@ class LlamaParse(BasePydanticReader):
azure_openai_key: Optional[str] = Field(
default=None, description="Azure Openai Key"
)
http_proxy: Optional[str] = Field(
default=None,
description="(optional) If set with input_url will use the specified http proxy to download the file.",
)

@field_validator("api_key", mode="before", check_fields=True)
@classmethod
@@ -221,6 +224,28 @@ async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]:
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
yield client

def _is_input_url(self, file_path: FileInput) -> bool:
"""Check if the input is a valid URL.
This method checks for:
- Proper URL scheme (http/https)
- Valid URL structure
- Network location (domain)
"""
if not isinstance(file_path, str):
return False
try:
result = urlparse(file_path)
return all(
[
result.scheme in ("http", "https"),
result.netloc, # Has domain
result.scheme, # Has scheme
]
)
except Exception:
return False

# upload a document and get back a job_id
async def _create_job(
self,
@@ -232,6 +257,7 @@ async def _create_job(
url = f"{self.base_url}/api/parsing/upload"
files = None
file_handle = None
input_url = file_input if self._is_input_url(file_input) else None

if isinstance(file_input, (bytes, BufferedIOBase)):
if not extra_info or "file_name" not in extra_info:
@@ -241,6 +267,8 @@ async def _create_job(
file_name = extra_info["file_name"]
mime_type = mimetypes.guess_type(file_name)[0]
files = {"file": (file_name, file_input, mime_type)}
elif input_url is not None:
files = None
elif isinstance(file_input, (str, Path, PurePosixPath, PurePath)):
file_path = str(file_input)
file_ext = os.path.splitext(file_path)[1].lower()
@@ -316,6 +344,13 @@ async def _create_job(
if self.azure_openai_key is not None:
data["azure_openai_key"] = self.azure_openai_key

if input_url is not None:
files = None
data["input_url"] = str(input_url)

if self.http_proxy is not None:
data["http_proxy"] = self.http_proxy

try:
async with self.client_context() as client:
response = await client.post(
@@ -332,12 +367,6 @@ async def _create_job(
if file_handle is not None:
file_handle.close()

@staticmethod
def __get_filename(f: Union[TextIOWrapper, AbstractBufferedFile]) -> str:
if isinstance(f, TextIOWrapper):
return f.name
return f.full_name

async def _get_job_result(
self, job_id: str, result_type: str, verbose: bool = False
) -> Dict[str, Any]:
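The new `_is_input_url` helper above decides whether an input is sent as an `input_url` form field or uploaded as a file. A minimal standalone sketch of that check, assuming only the standard-library `urlparse` behavior shown in the diff (the `looks_like_url` name is illustrative, not part of the library):

from urllib.parse import urlparse

def looks_like_url(candidate: object) -> bool:
    # Mirrors LlamaParse._is_input_url: only http/https strings with a
    # network location are treated as URL inputs; bytes, file objects,
    # and local paths continue to be uploaded as files.
    if not isinstance(candidate, str):
        return False
    parsed = urlparse(candidate)
    return parsed.scheme in ("http", "https") and bool(parsed.netloc)

assert looks_like_url("https://www.google.com")
assert not looks_like_url("tests/test_files/attention_is_all_you_need.pdf")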
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "llama-parse"
version = "0.5.13"
version = "0.5.14"
description = "Parse files into RAG-Optimized formats."
authors = ["Logan Markewich <[email protected]>"]
license = "MIT"
63 changes: 56 additions & 7 deletions tests/test_reader.py
@@ -76,27 +76,29 @@ def test_simple_page_markdown_buffer(markdown_parser: LlamaParse) -> None:
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
def test_simple_page_with_custom_fs() -> None:
@pytest.mark.asyncio
async def test_simple_page_with_custom_fs() -> None:
parser = LlamaParse(result_type="markdown")
fs = LocalFileSystem()
filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
result = parser.load_data(filepath, fs=fs)
result = await parser.aload_data(filepath, fs=fs)
assert len(result) == 1


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
def test_simple_page_progress_workers() -> None:
@pytest.mark.asyncio
async def test_simple_page_progress_workers() -> None:
parser = LlamaParse(result_type="markdown", show_progress=True, verbose=True)

filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
result = parser.load_data([filepath, filepath])
result = await parser.aload_data([filepath, filepath])
assert len(result) == 2
assert len(result[0].text) > 0

@@ -107,7 +109,7 @@ def test_simple_page_progress_workers() -> None:
filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
result = parser.load_data([filepath, filepath])
result = await parser.aload_data([filepath, filepath])
assert len(result) == 2
assert len(result[0].text) > 0

@@ -116,12 +118,59 @@ def test_simple_page_progress_workers() -> None:
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
def test_custom_client() -> None:
@pytest.mark.asyncio
async def test_custom_client() -> None:
custom_client = AsyncClient(verify=False, timeout=10)
parser = LlamaParse(result_type="markdown", custom_client=custom_client)
filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
result = parser.load_data(filepath)
result = await parser.aload_data(filepath)
assert len(result) == 1
assert len(result[0].text) > 0


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
@pytest.mark.asyncio
async def test_input_url() -> None:
parser = LlamaParse(result_type="markdown")

# links to a resume example
input_url = "https://cdn-blog.novoresume.com/articles/google-docs-resume-templates/basic-google-docs-resume.png"
result = await parser.aload_data(input_url)
assert len(result) == 1
assert "your name" in result[0].text.lower()


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
@pytest.mark.asyncio
async def test_input_url_with_website_input() -> None:
parser = LlamaParse(result_type="markdown")
input_url = "https://www.google.com"
result = await parser.aload_data(input_url)
assert len(result) == 1
assert "google" in result[0].text.lower()


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
@pytest.mark.asyncio
async def test_mixing_input_types() -> None:
parser = LlamaParse(result_type="markdown")
filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
input_url = "https://www.google.com"
result = await parser.aload_data([filepath, input_url])

assert len(result) == 2
assert "table 2" in result[0].text.lower()
assert "google" in result[1].text.lower()
