RAG with LangChain: answering questions from press releases using LLMs!
RAG (Retrieval Augmented Generation) is a technique that leverages text embeddings (hidden-layer representations from language models) to build a vector database of context that can be searched, with the most relevant results fed into an LLM's (Large Language Model's) input query.
This allows us to artificially increase our model’s “context” without modifying the model’s maximum token length or performing computationally expensive fine-tuning. That is, we can dynamically update an LLM’s “knowledge base” with our own data!
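To make the idea concrete before we build the real thing, here is a minimal illustrative sketch of that flow; embed, vector_search and llm are stand-in placeholders, not real APIs:
def answer_with_rag(question, vector_db, embed, vector_search, llm):
    # 1. Embed the user's question into a vector
    query_vector = embed(question)
    # 2. Look up the most similar context chunks in the vector database
    chunks = vector_search(vector_db, query_vector, k=4)
    # 3. Inject the retrieved context into the prompt alongside the question
    prompt = f"Context:\n{chunks}\n\nQuestion: {question}"
    return llm(prompt)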
This is a simple diagram of what our RAG design is going to look like:
Let’s start by installing some packages. We’ll be using LangChain to orchestrate most of the RAG flow, the OpenAI API as our LLM, and ChromaDB as our vector database:
%pip install --upgrade --quiet langchain langchain-community langchainhub langchain-openai chromadb bs4
After you have done this, make sure your OpenAI API key is available as an environment variable; alternatively, set it using os.environ in your script (NB: hard-coding keys like this is not best practice!).
import os
os.environ["OPENAI_API_KEY"] = "your_open_ai_api_key"
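If you are working in a notebook, a slightly safer pattern is to prompt for the key at runtime with the standard library’s getpass instead of hard-coding it:
import getpass
import os

# Ask for the key interactively so it never ends up committed in the script
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")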
We are now going to import the different helper functions we will use for this project from LangChain.
import bs4
from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
Now comes the fun part: let’s create a list of links from which we are going to extract the post content, which we will then parse to build our vector database.
I am going to be using this very interesting article about competition in the housing market in the United Kingdom.
# Load, chunk and index the contents of the press release.
urls = ["https://www.gov.uk/government/news/cma-finds-fundamental-concerns-in-housebuilding-market"]
loader = WebBaseLoader(
    web_paths=urls,
)
docs = loader.load()
Once we have created our document loader, we are going to break the documents into smaller chunks that fit into our LLM’s limited token context window, using the RecursiveCharacterTextSplitter class.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
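As an optional sanity check, we can see how many chunks were produced and what one of them looks like:
# How many chunks did we end up with, and what does one look like?
print(len(splits))
print(splits[0].page_content[:300])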
We now construct our vector store:
vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings(model="text-embedding-3-large"))
And then build our RAG pipeline’s elements: the retriever from which we will be getting our context, the prompt combining context and query, and the LLM.
# Retrieve and generate using the relevant snippets of the press release.
retriever = vectorstore.as_retriever()
prompt = hub.pull("rlm/rag-prompt")
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
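Before wiring up the full chain, you can query the retriever on its own to see which chunks it pulls back; the query string below is just an example:
# Retrievers are runnables, so we can invoke them directly with a query
sample_docs = retriever.invoke("estate management charges")
for doc in sample_docs:
    print(doc.page_content[:200], "\n---")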
If you want to see what the prompt we are using looks like, check it out here! The prompt instructs the model to use the user’s input query and the relevant context chunks we have retrieved to produce an answer.
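You can also print the pulled prompt locally to see how the context and question placeholders are combined:
# Show the prompt template pulled from the LangChain Hub
print(prompt)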
Once we have done this, we can build our rag_chain object from which we will be sampling!
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
We can now call it with a question about the web page we supplied:
rag_chain.invoke("What issues did the CMA find with the housing market?")
The CMA found issues with the planning system and speculative private development leading to under-delivery of new homes. Homeowners faced high and unclear estate management charges, and there were concerns about the quality of new housing due to snagging issues. The CMA also opened an investigation into housebuilders suspected of sharing commercially sensitive information that could influence build-out of sites and new home prices.
What if we want to know where the LLM got the answer from?
LangChain allows us to retrieve the context blocks that were originally passed to the LLM. The trick is RunnablePassthrough.assign(...): called this way, the runnable passes its input through unchanged and adds to the output any extra keys defined inside the assign call.
from langchain_core.runnables import RunnableParallel
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | prompt
    | llm
    | StrOutputParser()
)

rag_chain_with_source = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
response = rag_chain_with_source.invoke("What issues did the CMA find with the housing market?")
response.keys()
From this, rather than just text, we get a dictionary with the keys ['context', 'question', 'answer'], where context contains the chunks in our vector database that were most similar to the user’s query, i.e. the sources for the LLM’s response!
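For example, to list where each retrieved chunk came from, we can loop over the returned context documents; WebBaseLoader typically records the page URL under the 'source' metadata key:
# Print the source URL and a preview of each chunk that was fed to the LLM
for doc in response["context"]:
    print(doc.metadata.get("source"))
    print(doc.page_content[:200], "\n---")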
Uncover more layers below!
💻 Code: https://gist.github.com/phgelado/07a5bd3b2f6682df6286a8bab4e38c1c