Skip to main content

RELLM

RELLM is a library that wraps local Hugging Face pipeline models for structured decoding.

It works by generating tokens one at a time. At each step, it masks tokens that don't conform to the provided partial regular expression.

Warning - this module is still experimental

!pip install rellm > /dev/null

Hugging Face Baseline

First, let's establish a qualitative baseline by checking the output of the model without structured decoding.

import logging

logging.basicConfig(level=logging.ERROR)
prompt = """Human: "What's the capital of the United States?"
AI Assistant:{
"action": "Final Answer",
"action_input": "The capital of the United States is Washington D.C."
}
Human: "What's the capital of Pennsylvania?"
AI Assistant:{
"action": "Final Answer",
"action_input": "The capital of Pennsylvania is Harrisburg."
}
Human: "What 2 + 5?"
AI Assistant:{
"action": "Final Answer",
"action_input": "2 + 5 = 7."
}
Human: 'What's the capital of Maryland?'
AI Assistant:"""
from transformers import pipeline
from langchain.llms import HuggingFacePipeline

hf_model = pipeline(
"text-generation", model="cerebras/Cerebras-GPT-590M", max_new_tokens=200
)

original_model = HuggingFacePipeline(pipeline=hf_model)

generated = original_model.generate([prompt], stop=["Human:"])
print(generated)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


generations=[[Generation(text=' "What\'s the capital of Maryland?"\n', generation_info=None)]] llm_output=None

That's not so impressive, is it? It didn't answer the question and it didn't follow the JSON format at all! Let's try with the structured decoder.

RELLM LLM Wrapper

Let's try that again, now providing a regex to match the JSON structured format.

import regex  # Note this is the regex library NOT python's re stdlib module

# We'll choose a regex that matches to a structured json string that looks like:
# {
# "action": "Final Answer",
# "action_input": string or dict
# }
pattern = regex.compile(
r'\{\s*"action":\s*"Final Answer",\s*"action_input":\s*(\{.*\}|"[^"]*")\s*\}\nHuman:'
)
from langchain.experimental.llms import RELLM

model = RELLM(pipeline=hf_model, regex=pattern, max_new_tokens=200)

generated = model.predict(prompt, stop=["Human:"])
print(generated)
{"action": "Final Answer",
"action_input": "The capital of Maryland is Baltimore."
}

Voila! Free of parsing errors.