Skip to content

LLM Integrations

Bases: Solver

A solver that uses an LLM.

This is for research purposes only. LLM output can be unreliable, slow, and expensive, and should not be used as a replacement for a deterministic logic-based solver.

Example:

    from typedlogic import Term
    from typedlogic.integrations.frameworks.pydantic import FactBaseModel
    class AncestorOf(FactBaseModel):
        ancestor: str
        descendant: str
    from typedlogic import SentenceGroup, PredicateDefinition
    solver = LLMSolver(model_name="gpt-4o")
    solver.add_predicate_definition(PredicateDefinition(predicate="AncestorOf", arguments={'ancestor': str, 'descendant': str}))
    solver.add_fact(AncestorOf(ancestor='p1', descendant='p1a'))
    solver.add_fact(AncestorOf(ancestor='p1a', descendant='p1aa'))

    aa = SentenceGroup(name="transitivity-of-ancestor-of")
    solver.add_sentence_group(aa)
    soln = solver.prove(Term("AncestorOf", "p1", "p1aa"))

This makes use of the datasette LLM package (https://llm.datasette.io/). Consult its documentation for details on how to set up keys, use alternative models, etc.

Source code in src/typedlogic/integrations/solvers/llm/llm_solver.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@dataclass
class LLMSolver(Solver):
    """
    A solver that uses an LLM.

    This is for research purposes only. LLM output can be unreliable, slow, and expensive,
    and should not be used as a replacement for a deterministic logic-based solver.

    Example
    -------
    ```python
    from typedlogic import Term, SentenceGroup, PredicateDefinition
    from typedlogic.integrations.frameworks.pydantic import FactBaseModel

    class AncestorOf(FactBaseModel):
        ancestor: str
        descendant: str

    solver = LLMSolver(model_name="gpt-4o")
    solver.add_predicate_definition(
        PredicateDefinition(predicate="AncestorOf", arguments={'ancestor': str, 'descendant': str})
    )
    solver.add_fact(AncestorOf(ancestor='p1', descendant='p1a'))
    solver.add_fact(AncestorOf(ancestor='p1a', descendant='p1aa'))

    aa = SentenceGroup(name="transitivity-of-ancestor-of")
    solver.add_sentence_group(aa)
    soln = solver.prove(Term("AncestorOf", "p1", "p1aa"))
    ```

    This makes use of the [datasette LLM](https://llm.datasette.io/) package. Consult its
    documentation for details on how to set up keys, use alternative models, etc.
    """

    # Model identifier passed to ``llm.get_model``.
    model_name: str = field(default="gpt-4o")
    # Syntax name passed to ``get_compiler``; the theory and goals are rendered in this
    # syntax when building the prompt.
    fol_syntax: str = field(default="fol")
    profile: ClassVar[Profile] = MixedProfile(Unrestricted(), OpenWorld())

    def models(self) -> Iterator[Model]:
        """
        Yield models of the current theory.

        The LLM solver does not construct models: an empty ``Model()`` is yielded only if
        ``check`` reports the theory satisfiable. This class's ``check`` returns
        ``satisfiable=None``, so by default nothing is yielded.
        """
        r = self.check()
        if r.satisfiable:
            yield Model()

    def prove(self, sentence: Sentence) -> Optional[bool]:
        """
        Ask the LLM whether ``sentence`` follows from the current theory.

        :param sentence: the goal to prove.
        :return: True if the LLM lists the goal as provable, False if it lists it as not
            provable, and None if its answer does not (unambiguously) classify the goal.
        """
        for _, verdict in self.prove_multiple([sentence]):
            return verdict
        return None

    def prove_multiple(self, sentences: List[Sentence]) -> Iterable[Tuple[Sentence, Optional[bool]]]:
        """
        Ask the LLM, in a single prompt, which of ``sentences`` follow from the current theory.

        The theory and the goals are compiled with the ``fol_syntax`` compiler, the goals are
        numbered from 1, and the LLM's reply is parsed with ``parse_response``. The parsed
        reply is expected to be a mapping whose ``provable`` and ``not_provable`` entries list
        goal numbers.

        :param sentences: the goals to test.
        :return: exactly one ``(sentence, verdict)`` pair per input goal, in input order.
            ``verdict`` is True if the goal is listed only under ``provable``, False if it is
            listed only under ``not_provable``, and None if it is listed under neither or both.
            Numbers that do not correspond to a goal are ignored.
        """
        if not sentences:
            # Nothing to ask; avoid a (paid) LLM round-trip.
            return
        compiler = get_compiler(self.fol_syntax)
        program = compiler.compile(self.base_theory)
        model = llm.get_model(self.model_name)
        enumerated_goals = dict(enumerate(sentences, 1))
        goals = "\n".join(f"{i}: {compiler.compile_sentence(s)}" for i, s in enumerated_goals.items())
        prompt = TEMPLATE.format(program=program, goals=goals)
        response = model.prompt(prompt, system=SYSTEM)
        obj = self.parse_response(response.text())
        provable = self._goal_numbers(obj, "provable")
        not_provable = self._goal_numbers(obj, "not_provable")
        for i, sentence in enumerated_goals.items():
            in_provable = i in provable
            in_not_provable = i in not_provable
            if in_provable and not in_not_provable:
                yield sentence, True
            elif in_not_provable and not in_provable:
                yield sentence, False
            else:
                # Unclassified or contradictory answer: the LLM gave no usable verdict.
                yield sentence, None

    @staticmethod
    def _goal_numbers(obj: Any, key: str) -> List[int]:
        """
        Extract the goal numbers listed under ``key`` in a parsed LLM reply.

        Tolerates replies that are not mappings, missing or null keys, a scalar in place of a
        list, and entries that cannot be converted with ``int`` (those entries are skipped).
        """
        if not isinstance(obj, dict):
            return []
        raw = obj.get(key)
        if raw is None:
            return []
        if not isinstance(raw, list):
            raw = [raw]
        numbers = []
        for item in raw:
            try:
                numbers.append(int(item))
            except (TypeError, ValueError):
                continue
        return numbers

    def parse_response(self, text: str) -> Any:
        """
        Parse the LLM's reply as YAML.

        If the reply contains a Markdown code fence, only the content of the first fenced
        block is parsed, and a leading ``yaml``/``yml``/``json`` language tag
        (case-insensitive) is dropped. ``yaml.safe_load`` is used because the text is
        untrusted model output.

        :param text: raw reply text from the model.
        :return: the parsed YAML value (may be None, a scalar, a list, or a mapping).
        """
        if "```" in text:
            text = text.split("```")[1].strip()
            # Split off the first whitespace-delimited token so that only an exact language
            # tag is removed, not the start of a key such as ``yaml:``.
            parts = text.split(None, 1)
            if parts and parts[0].lower() in ("yaml", "yml", "json"):
                text = parts[1] if len(parts) > 1 else ""
        return yaml.safe_load(text)

    def check(self) -> Solution:
        """
        Satisfiability checking is not supported by the LLM solver.

        :return: a ``Solution`` with ``satisfiable=None`` (undetermined).
        """
        return Solution(satisfiable=None)