mirror of
https://github.com/AdaCore/ada-eval.git
synced 2026-02-12 13:53:19 -08:00
604 lines
25 KiB
Python
604 lines
25 KiB
Python
import json
|
|
import re
|
|
import shutil
|
|
from logging import WARN
|
|
from pathlib import Path
|
|
|
|
import pydantic
|
|
import pytest
|
|
from helpers import assert_log, setup_git_repo
|
|
|
|
from ada_eval.datasets import Dataset, dataset_has_sample_type
|
|
from ada_eval.datasets.loader import (
|
|
DuplicateSampleNameError,
|
|
InvalidDatasetError,
|
|
InvalidDatasetNameError,
|
|
MixedDatasetFormatsError,
|
|
MixedSampleTypesError,
|
|
UnknownDatasetKindError,
|
|
load_datasets,
|
|
load_packed_dataset,
|
|
load_unpacked_dataset,
|
|
)
|
|
from ada_eval.datasets.types.directory_contents import DirectoryContents
|
|
from ada_eval.datasets.types.evaluation_stats import (
|
|
EvaluationStatsBuild,
|
|
EvaluationStatsFailed,
|
|
EvaluationStatsProve,
|
|
EvaluationStatsTimedOut,
|
|
ProofCheck,
|
|
)
|
|
from ada_eval.datasets.types.samples import (
|
|
EVALUATED_SAMPLE_TYPES,
|
|
GENERATED_SAMPLE_TYPES,
|
|
AdaSample,
|
|
EvaluatedAdaSample,
|
|
EvaluatedSample,
|
|
ExplainSample,
|
|
GeneratedSample,
|
|
GenerationStats,
|
|
Location,
|
|
Sample,
|
|
SampleKind,
|
|
SampleStage,
|
|
SparkSample,
|
|
)
|
|
from ada_eval.utils import UnexpectedTypeError
|
|
|
|
|
|
def expected_base_sample_fields(
    sample_name: str, dataset_dirname: str
) -> dict[str, object]:
    """Return expected fields common to (almost) all `Samples` in the test datasets."""
    # All textual fields embed the sample and dataset names, matching the
    # fixture data on disk.
    prompt_text = (
        f"This is the prompt for sample '{sample_name}' from dataset "
        f"'{dataset_dirname}'.\n"
    )
    source_text = (
        f"This is 'source_file_0' in sample '{sample_name}' from "
        f"dataset '{dataset_dirname}'.\n"
    )
    comment_text = (
        f"This is a comment on sample '{sample_name}' from dataset "
        f"'{dataset_dirname}'.\n"
    )
    fields: dict[str, object] = {
        "name": sample_name,
        "location": Location(
            path=Path("source_file_0"), subprogram_name="My_Subprogram"
        ),
        "prompt": prompt_text,
        "sources": DirectoryContents({Path("source_file_0"): source_text.encode()}),
        "comments": comment_text,
        "canonical_evaluation_results": [],
    }
    return fields
|
|
|
|
|
|
def expected_explain_sample(sample_name: str, dataset_dirname: str) -> ExplainSample:
    """Return an `ExplainSample` matching that expected from the test datasets."""
    base_fields = expected_base_sample_fields(sample_name, dataset_dirname)
    reference_answer = (
        f"This is the reference answer for sample '{sample_name}' from "
        f"dataset '{dataset_dirname}'.\n"
    )
    return ExplainSample(
        **base_fields,
        canonical_solution=reference_answer,
        correct_statements=[
            "This is a correct statement.",
            "This is another correct statement.",
        ],
        incorrect_statements=[
            "This is an incorrect statement.",
            "This is another incorrect statement.",
        ],
    )
|
|
|
|
|
|
def expected_ada_sample(sample_name: str, dataset_dirname: str) -> AdaSample:
    """Return an `AdaSample` mostly matching those expected from the test datasets."""
    # The canonical solution is the base source file with one line appended.
    solution_text = (
        f"This is 'source_file_0' in sample '{sample_name}' from "
        f"dataset '{dataset_dirname}'.\nThis is a new line added as "
        "part of the canonical solution.\n"
    )
    unit_test_text = (
        f"This is a unit test for sample '{sample_name}' from dataset "
        f"'{dataset_dirname}'.\n"
    )
    return AdaSample(
        **expected_base_sample_fields(sample_name, dataset_dirname),
        canonical_solution=DirectoryContents(
            {Path("source_file_0"): solution_text.encode()}
        ),
        unit_tests=DirectoryContents(
            {Path("unit_test_file_0"): unit_test_text.encode()}
        ),
    )
|
|
|
|
|
|
def expected_spark_sample(sample_name: str, dataset_dirname: str) -> SparkSample:
    """Return a `SparkSample` mostly matching those expected from the test datasets."""
    # A SparkSample is an AdaSample plus the `required_checks` field.
    ada_fields = expected_ada_sample(sample_name, dataset_dirname).model_dump()
    required_check = ProofCheck(
        rule="RULE_NAME",
        entity_name="My_Package.My_Subprogram",
        src_pattern="pattern",
    )
    return SparkSample(**ada_fields, required_checks=[required_check])
|
|
|
|
|
|
def expected_generated_sample(base_sample: Sample) -> GeneratedSample:
    """Return the expected `GeneratedSample` corresponding to a base sample."""
    # Ada-family samples produce modified sources; Explain samples produce a
    # plain-text explanation.
    generated_solution: object
    if isinstance(base_sample, AdaSample):
        extra_file = {
            Path("generated_file"): b"This file was added during generation\n"
        }
        generated_solution = DirectoryContents(base_sample.sources.files | extra_file)
    else:
        generated_solution = "This is the generated explanation."
    stats = GenerationStats(
        exit_code=0,
        stdout="This is the generation's stdout\n",
        stderr="",
        runtime_ms=0,
    )
    return GENERATED_SAMPLE_TYPES[base_sample.kind](
        **base_sample.model_dump(),
        generation_stats=stats,
        generated_solution=generated_solution,
    )
|
|
|
|
|
|
def expected_evaluated_sample(base_sample: Sample) -> EvaluatedSample:
    """Return the expected `EvaluatedSample` corresponding to a base sample."""
    generated_sample = expected_generated_sample(base_sample)
    # The fixture evaluation results consist of one prove run and one build run.
    prove_stats = EvaluationStatsProve(
        result="unproved",
        proved_checks={"PROVED_CHECK_NAME": 1},
        unproved_checks={"UNPROVED_CHECK_NAME": 2},
        warnings={"WARNING_NAME": 3},
        non_spark_entities=["Entity_Name"],
        missing_required_checks=[
            ProofCheck(
                rule="RULE_NAME",
                entity_name="My_Package.My_Subprogram",
                src_pattern="pattern",
            )
        ],
        pragma_assume_count=4,
    )
    build_stats = EvaluationStatsBuild(
        compiled=False, pre_format_warnings=True, post_format_warnings=True
    )
    return EVALUATED_SAMPLE_TYPES[generated_sample.kind](
        **generated_sample.model_dump(),
        evaluation_results=[prove_stats, build_stats],
    )
|
|
|
|
|
|
def check_loaded_datasets(
    datasets: list[Dataset[Sample]], stage: SampleStage = SampleStage.INITIAL
) -> None:
    """Check that `datasets` matches `tests/data/valid_[base/generated]_datasets`."""

    def promoted_if_needed(sample: Sample) -> Sample:
        """Promote an expected base sample to the form expected at `stage`."""
        match stage:
            case SampleStage.INITIAL:
                return sample
            case SampleStage.GENERATED:
                return expected_generated_sample(sample)
            case SampleStage.EVALUATED:
                return expected_evaluated_sample(sample)

    # Exactly one Explain, one Ada, and one Spark dataset are expected
    assert len(datasets) == 3
    datasets_by_name = {d.dirname: d for d in datasets}

    # Check the Explain dataset
    explain_dataset = datasets_by_name["explain_test"]
    assert explain_dataset.name == "test"
    assert explain_dataset.kind is SampleKind.EXPLAIN
    assert explain_dataset.stage is stage
    assert explain_dataset.samples == [
        promoted_if_needed(expected_explain_sample("test_sample_0", "explain_test"))
    ]

    # Construct the expected sample for the Ada dataset (i.e. that returned by
    # `expected_ada_sample()`, except with some `canonical_evaluation_results`)
    expected_ada_sample_0 = expected_ada_sample("test_sample_0", "ada_test")
    expected_ada_sample_0.canonical_evaluation_results = [
        EvaluationStatsBuild(
            compiled=True, pre_format_warnings=True, post_format_warnings=False
        )
    ]
    # Check the Ada dataset
    ada_dataset = datasets_by_name["ada_test"]
    assert ada_dataset.name == "test"
    assert ada_dataset.kind is SampleKind.ADA
    assert ada_dataset.stage is stage
    assert ada_dataset.samples == [promoted_if_needed(expected_ada_sample_0)]

    # Construct expected samples for the Spark dataset (sample_0 is mostly
    # as returned by `expected_spark_sample()`, sample_1 has some extra files,
    # and sample_2 is empty apart from a minimal `other.json` file, so should be
    # populated with defaults)
    expected_spark_sample_0 = expected_spark_sample("test_sample_0", "spark_test")
    expected_spark_sample_0.canonical_evaluation_results = [
        EvaluationStatsTimedOut(
            eval="prove", cmd_timed_out=["cmd", "arg0", "arg1"], timeout=12.34
        ),
        EvaluationStatsFailed(eval="build", exception='SomeError("Some message")'),
    ]
    expected_spark_sample_1 = expected_spark_sample("test_sample_1", "spark_test")
    # sample_1 has an extra source file and an extra canonical-solution file
    expected_spark_sample_1.sources.files[Path("source_dir_0/source_file_1")] = (
        b"This is 'source_file_1' in sample 'test_sample_1' from dataset "
        b"'spark_test'.\n"
    )
    expected_spark_sample_1.canonical_solution.files[
        Path("source_dir_1/source_file_2")
    ] = (
        b"This is 'source_file_2' in sample 'test_sample_1' from dataset "
        b"'spark_test'.\nThe addition of this file is part of the canonical solution.\n"
    )
    expected_spark_sample_1.canonical_evaluation_results = [
        EvaluationStatsProve(
            result="proved",
            proved_checks={"PROVED_CHECK_NAME": 123},
            unproved_checks={},
            warnings={},
            non_spark_entities=[],
            missing_required_checks=[],
            pragma_assume_count=0,
        ),
        EvaluationStatsBuild(
            compiled=True, pre_format_warnings=False, post_format_warnings=False
        ),
    ]
    # sample_2 exercises the defaulting behaviour: everything except the name
    # and location should be the model defaults
    expected_spark_sample_2 = SparkSample(
        name="test_sample_2",
        location=Location(
            path=Path("non/existent/path"), subprogram_name="My_Subprogram"
        ),
        prompt="",
        sources=DirectoryContents({}),
        canonical_solution=DirectoryContents({}),
        canonical_evaluation_results=[],
        comments="",
        unit_tests=DirectoryContents({}),
    )
    # Check the Spark dataset (note that the sample ordering is not guaranteed)
    spark_dataset = datasets_by_name["spark_test"]
    assert spark_dataset.name == "test"
    assert spark_dataset.kind is SampleKind.SPARK
    assert spark_dataset.stage is stage
    spark_samples_by_name = {s.name: s for s in spark_dataset.samples}
    assert spark_samples_by_name == {
        "test_sample_0": promoted_if_needed(expected_spark_sample_0),
        "test_sample_1": promoted_if_needed(expected_spark_sample_1),
        "test_sample_2": promoted_if_needed(expected_spark_sample_2),
    }
|
|
|
|
|
|
def test_load_valid_unpacked_datasets(expanded_test_datasets: Path):
    """Check that loading unpacked datasets works correctly."""
    loaded = load_datasets(expanded_test_datasets)
    check_loaded_datasets(loaded)
|
|
|
|
|
|
def test_load_valid_packed_datasets(compacted_test_datasets: Path):
    """Check that loading packed datasets works correctly."""
    loaded = load_datasets(compacted_test_datasets)
    check_loaded_datasets(loaded)
|
|
|
|
|
|
def test_load_valid_packed_generated_datasets(generated_test_datasets: Path):
    """Check that loading packed generated datasets works correctly."""
    loaded = load_datasets(generated_test_datasets)
    check_loaded_datasets(loaded, SampleStage.GENERATED)
|
|
|
|
|
|
def test_load_valid_packed_evaluated_datasets(evaluated_test_datasets: Path):
    """Check that loading packed evaluated datasets works correctly."""
    loaded = load_datasets(evaluated_test_datasets)
    check_loaded_datasets(loaded, SampleStage.EVALUATED)
|
|
|
|
|
|
def test_load_valid_unpacked_datasets_with_gitignore(expanded_test_datasets: Path):
    """Check loading unpacked datasets with `.gitignore` files works correctly."""
    # Create a `.gitignore` file which ignores any `obj/` directories
    gitignore_path = expanded_test_datasets / ".gitignore"
    gitignore_path.write_text("obj/\n")
    # Add some files in `obj/` directories to the dataset
    ada_sample_0_dir = expanded_test_datasets / "ada_test" / "test_sample_0"
    spark_sample_1_dir = expanded_test_datasets / "spark_test" / "test_sample_1"
    for source_dir in (
        ada_sample_0_dir / "base",
        spark_sample_1_dir / "solution" / "source_dir_1",
    ):
        (source_dir / "obj").mkdir(exist_ok=True)
        (source_dir / "obj" / "some_file").write_text("This is a test file.\n")

    # Load and check the datasets
    datasets = load_datasets(expanded_test_datasets)

    # Check that the `obj/some_file` files are present in the loaded datasets
    # because `expanded_test_dataset` is not in a Git worktree.
    # (Ignore rules are presumably only honoured inside a Git worktree.)
    ada_dataset = next(d for d in datasets if d.sample_type is AdaSample)
    assert ada_dataset.samples[0].name == "test_sample_0"
    ada_ignored_rel_path = Path("obj/some_file")
    assert ada_ignored_rel_path in ada_dataset.samples[0].sources.files
    assert ada_dataset.samples[0].sources.files[ada_ignored_rel_path] == (
        b"This is a test file.\n"
    )
    spark_dataset = next(d for d in datasets if dataset_has_sample_type(d, SparkSample))
    spark_sample_1 = next(s for s in spark_dataset.samples if s.name == "test_sample_1")
    spark_ignored_rel_path = Path("source_dir_1/obj/some_file")
    assert spark_ignored_rel_path in spark_sample_1.canonical_solution.files
    assert spark_sample_1.canonical_solution.files[spark_ignored_rel_path] == (
        b"This is a test file.\n"
    )
    # The loaded datasets should otherwise be as expected
    # (pop the extra files so `check_loaded_datasets` sees the baseline data)
    ada_dataset.samples[0].sources.files.pop(ada_ignored_rel_path)
    spark_sample_1.canonical_solution.files.pop(spark_ignored_rel_path)
    check_loaded_datasets(datasets)

    # Initialise a Git repository in the dataset directory and reload
    setup_git_repo(expanded_test_datasets)
    datasets = load_datasets(expanded_test_datasets)

    # Check that this time the `obj/some_file` files are not present in the
    # loaded datasets
    ada_dataset = next(d for d in datasets if d.sample_type is AdaSample)
    assert ada_dataset.samples[0].name == "test_sample_0"
    assert ada_ignored_rel_path not in ada_dataset.samples[0].sources.files
    spark_dataset = next(d for d in datasets if dataset_has_sample_type(d, SparkSample))
    spark_sample_1 = next(s for s in spark_dataset.samples if s.name == "test_sample_1")
    assert spark_ignored_rel_path not in spark_sample_1.canonical_solution.files
    check_loaded_datasets(datasets)
|
|
|
|
|
|
def test_load_empty_packed_dataset(
    tmp_path: Path, capsys: pytest.CaptureFixture[str], caplog: pytest.LogCaptureFixture
):
    """Check that loading an empty jsonl file yields an appropriate empty dataset."""
    empty_dataset_path = tmp_path / "ada_empty.jsonl"
    empty_dataset_path.touch()

    loaded = load_datasets(empty_dataset_path)

    # A single empty dataset is produced, with metadata derived from the
    # filename and the most-evaluated sample type for its kind
    assert len(loaded) == 1
    dataset = loaded[0]
    assert dataset.name == "empty"
    assert dataset.sample_type is EvaluatedAdaSample
    assert dataset.samples == []

    # A warning is logged, but nothing is written to stdout/stderr
    assert_log(caplog, WARN, f"Dataset at '{empty_dataset_path}' is empty.")
    captured = capsys.readouterr()
    assert captured.out == ""
    assert captured.err == ""
|
|
|
|
|
|
def test_load_no_valid_samples(
    compacted_test_datasets: Path,
    expanded_test_datasets: Path,
    caplog: pytest.LogCaptureFixture,
):
    """Check that loading a dataset with no valid samples produces a warning."""
    # Strip the `.jsonl` suffix from every packed dataset: none should load,
    # and a warning should be issued
    for packed_dataset in compacted_test_datasets.iterdir():
        packed_dataset.rename(packed_dataset.with_suffix(""))
    loaded = load_datasets(compacted_test_datasets)
    assert len(loaded) == 0
    assert_log(
        caplog, WARN, f"No datasets could be found at: {compacted_test_datasets}"
    )

    # Delete `other.json` from every unpacked sample: again none should load,
    # and a warning should be issued
    for other_json in expanded_test_datasets.glob("**/other.json"):
        other_json.unlink()
    loaded = load_datasets(expanded_test_datasets)
    assert len(loaded) == 0
    assert_log(caplog, WARN, f"No datasets could be found at: {expanded_test_datasets}")
|
|
|
|
|
|
def test_load_mixed_datasets(
    tmp_path: Path, compacted_test_datasets: Path, expanded_test_datasets: Path
):
    """Check that loading directory with mixed packed/unpacked datasets gives error."""
    # Merge the packed and the unpacked fixture datasets into one directory
    mixed_dir = tmp_path / "mixed"
    mixed_dir.mkdir()
    for fixture_dir in (compacted_test_datasets, expanded_test_datasets):
        for dataset in fixture_dir.iterdir():
            shutil.move(dataset, mixed_dir / dataset.name)

    error_msg = f"'{mixed_dir}' contains a mixture of packed and unpacked datasets."
    with pytest.raises(MixedDatasetFormatsError, match=re.escape(error_msg)):
        load_datasets(mixed_dir)
|
|
|
|
|
|
def test_load_invalid_samples(
    compacted_test_datasets: Path, expanded_test_datasets: Path
):
    """Test that exceptions while loading samples specify which sample raised them."""
    # Make one of the samples in `spark_test.jsonl` invalid
    # (dropping the `name` field triggers a pydantic validation error)
    spark_test_path = compacted_test_datasets / "spark_test.jsonl"
    original_spark_test = spark_test_path.read_text()
    spark_test_path.write_text(
        original_spark_test.replace('"name":"test_sample_1",', "")
    )
    # The error message should end with a note pinpointing the failing line
    error_msg = r"^1 validation error for SparkSample\nname\n Field required .*\n.*\n"
    error_msg += re.escape(
        f"This error occurred while parsing line 2 of '{spark_test_path}'"
    )
    with pytest.raises(pydantic.ValidationError, match=error_msg):
        load_datasets(compacted_test_datasets)

    # Make an `other.json` file invalid JSON and check the resulting error
    sample_dir = expanded_test_datasets / "explain_test" / "test_sample_0"
    other_json_path = sample_dir / "other.json"
    original_other_json = other_json_path.read_text()
    other_json_path.write_text("This is not valid JSON")
    # All unpacked-sample errors should carry this location note
    location_note = (
        f"\nThis exception occurred while loading the sample at: {sample_dir}"
    )
    error_msg = "Expecting value: line 1 column 1 (char 0)" + location_note
    with pytest.raises(json.decoder.JSONDecodeError, match=re.escape(error_msg)):
        load_datasets(expanded_test_datasets)
    # Make the `other.json` file valid JSON, but with the wrong top-level type
    other_json_path.write_text("null")
    error_msg = "Expected type dict, but got NoneType." + location_note
    with pytest.raises(UnexpectedTypeError, match=re.escape(error_msg)):
        load_datasets(expanded_test_datasets)
    # Restore the `other.json` file, but with an invalid `Location`
    # (a `Location` without a `path` fails model validation)
    other_json_invalid_location = json.loads(original_other_json)
    other_json_invalid_location["location"].pop("path")
    other_json_path.write_text(json.dumps(other_json_invalid_location))
    error_msg = r"^1 validation error for Location\npath\n Field required .*\n.*"
    error_msg += re.escape(location_note)
    with pytest.raises(pydantic.ValidationError, match=error_msg):
        load_datasets(expanded_test_datasets)
|
|
|
|
|
|
def test_load_invalid_sample_name(
    compacted_test_datasets: Path, expanded_test_datasets: Path
):
    """Check that loading a dataset with an invalid sample name raises an error."""
    # Check packed: rewrite a sample name to contain an illegal '#' character
    spark_test_path = compacted_test_datasets / "spark_test.jsonl"
    original_spark_test = spark_test_path.read_text()
    spark_test_path.write_text(
        original_spark_test.replace('"name":"test_sample_1"', '"name":"test_sample_#1"')
    )
    error_msg = re.escape(
        "1 validation error for SparkSample\nname\n"
        " Value error, Invalid sample name: 'test_sample_#1'. Please only use "
        "alphanumeric characters, hyphens, and underscores. "
    )
    error_msg += r".*\n.*\n"
    error_msg += re.escape(
        f"This error occurred while parsing line 2 of '{spark_test_path}'"
    )
    with pytest.raises(pydantic.ValidationError, match=error_msg):
        load_datasets(compacted_test_datasets)

    # Check unpacked: rename a sample directory to contain an illegal '#'
    ada_sample_path = expanded_test_datasets / "ada_test" / "test_sample_0"
    ada_sample_path = ada_sample_path.rename(ada_sample_path.parent / "test_sample_#0")
    error_msg = re.escape(
        "1 validation error for Sample\nname\n"
        " Value error, Invalid sample name: 'test_sample_#0'. Please only use "
        "alphanumeric characters, hyphens, and underscores. "
    )
    error_msg += r".*\n.*\n"
    # Fix: this previously used `=` instead of `+=`, which discarded the
    # validation-error pattern built above and so only matched the trailing
    # location note (and did so without `re.escape`-ing the path). Append the
    # escaped note instead, mirroring the packed case above.
    error_msg += re.escape(
        f"This exception occurred while loading the sample at: {ada_sample_path}"
    )
    with pytest.raises(pydantic.ValidationError, match=error_msg):
        load_datasets(expanded_test_datasets)
|
|
|
|
|
|
def test_load_non_sample_warning(
    expanded_test_datasets: Path, caplog: pytest.LogCaptureFixture
):
    """Check that a non-sample file/directory is ignored with a warning."""
    # Deleting `other.json` turns one spark sample directory into a non-sample
    spark_sample_path = expanded_test_datasets / "spark_test" / "test_sample_0"
    (spark_sample_path / "other.json").unlink()

    loaded = load_datasets(expanded_test_datasets)

    # The two remaining spark samples still load; the skipped directory is
    # reported via a warning
    spark_dataset = next(d for d in loaded if dataset_has_sample_type(d, SparkSample))
    assert len(spark_dataset.samples) == 2
    assert_log(caplog, WARN, f"Skipping non-sample directory: {spark_sample_path}")
|
|
|
|
|
|
def test_load_invalid_dataset_name(
    compacted_test_datasets: Path, expanded_test_datasets: Path
):
    """Check that loading a dataset with an invalid name format raises an error."""
    # Packed: a filename without an underscore separator is rejected
    shutil.move(
        compacted_test_datasets / "ada_test.jsonl",
        compacted_test_datasets / "ada test.jsonl",
    )
    with pytest.raises(
        InvalidDatasetNameError,
        match="Expected packed dataset filename to contain an underscore:",
    ):
        load_datasets(compacted_test_datasets)

    # Unpacked: a directory name without an underscore separator is rejected
    shutil.move(
        expanded_test_datasets / "ada_test",
        expanded_test_datasets / "ada test",
    )
    with pytest.raises(
        InvalidDatasetNameError,
        match="Expected unpacked dataset dir name to contain an underscore:",
    ):
        load_datasets(expanded_test_datasets)
|
|
|
|
|
|
def test_load_invalid_dataset_kind(
    compacted_test_datasets: Path, expanded_test_datasets: Path
):
    """Check that loading a dataset with an invalid kind raises an error."""
    expected_error = re.escape("Unknown dataset kind: unknown")

    # Packed: the kind prefix of the filename is not a known `SampleKind`
    shutil.move(
        compacted_test_datasets / "ada_test.jsonl",
        compacted_test_datasets / "unknown_test.jsonl",
    )
    with pytest.raises(UnknownDatasetKindError, match=expected_error):
        load_datasets(compacted_test_datasets)

    # Unpacked: same check for the directory-name prefix
    shutil.move(
        expanded_test_datasets / "ada_test",
        expanded_test_datasets / "unknown_test",
    )
    with pytest.raises(UnknownDatasetKindError, match=expected_error):
        load_datasets(expanded_test_datasets)
|
|
|
|
|
|
def test_load_duplicate_sample_names(compacted_test_datasets: Path):
    """Check that loading a dataset with duplicate sample names raises an error."""
    # Rename sample 0 so that it clashes with sample 1 in the packed dataset
    spark_dataset_file = compacted_test_datasets / "spark_test.jsonl"
    content = spark_dataset_file.read_text().replace(
        '"name":"test_sample_0"', '"name":"test_sample_1"'
    )
    spark_dataset_file.write_text(content)

    error_msg = f"Duplicate sample name 'test_sample_1' found in '{spark_dataset_file}'"
    with pytest.raises(DuplicateSampleNameError, match=re.escape(error_msg)):
        load_datasets(compacted_test_datasets)
|
|
|
|
|
|
def test_load_mixed_sample_types(generated_test_datasets: Path):
    """Check that loading a dataset with mixed sample types raises an error."""
    spark_dataset_file = generated_test_datasets / "spark_test.jsonl"
    # Turn one of the generated samples into an evaluated one by adding an
    # `evaluation_results` field to its JSON record
    content = spark_dataset_file.read_text().replace(
        '"name":"test_sample_2"',
        '"name":"test_sample_2","evaluation_results":[]',
    )
    spark_dataset_file.write_text(content)

    error_msg = (
        f"Dataset at '{spark_dataset_file}' contains mixed sample types:\n"
        f"'test_sample_0' is GeneratedSparkSample "
        f"but 'test_sample_2' is EvaluatedSparkSample"
    )
    with pytest.raises(MixedSampleTypesError, match=re.escape(error_msg)):
        load_datasets(generated_test_datasets)
|
|
|
|
|
|
def test_load_unpacked_dataset_invalid(expanded_test_datasets: Path):
    """Check that `load_unpacked_dataset()` defensively raises on an invalid dataset."""
    # With its only sample's `other.json` removed, `ada_test` no longer
    # qualifies as an unpacked dataset
    ada_dataset_path = expanded_test_datasets / "ada_test"
    (ada_dataset_path / "test_sample_0" / "other.json").unlink()

    # Loading it directly with `load_unpacked_dataset` should raise
    expected_error = re.escape(f"'{ada_dataset_path}' is not a valid unpacked dataset")
    with pytest.raises(InvalidDatasetError, match=expected_error):
        load_unpacked_dataset(ada_dataset_path)
|
|
|
|
|
|
def test_load_packed_dataset_invalid(compacted_test_datasets: Path):
    """Check that `load_packed_dataset()` defensively raises on an invalid dataset."""
    # Without its `.jsonl` suffix, the ada dataset no longer qualifies as a
    # packed dataset
    renamed_path = (compacted_test_datasets / "ada_test.jsonl").rename(
        compacted_test_datasets / "ada_test"
    )

    # Loading it directly with `load_packed_dataset` should raise
    expected_error = re.escape(f"'{renamed_path}' is not a valid packed dataset")
    with pytest.raises(InvalidDatasetError, match=expected_error):
        load_packed_dataset(renamed_path)
|