import gc
import logging

import streamlit as st
from langchain_core.output_parsers import StrOutputParser
from langchain_ollama.chat_models import ChatOllama

from utils.embedding_model import init_embedding
from utils.vector_store import get_vector_store
from prompts.hyde import hyde_prompt, hyde_rag_prompt

# Right-to-left layout overrides so Persian text reads naturally.
# The stylesheet is raw HTML, which Streamlit only renders (instead of
# escaping) when unsafe_allow_html=True is passed.
_RTL_STYLESHEET = """
<style>
    html, body, [class*="css"] {
        direction: rtl;
        text-align: right;
    }
    button, input, optgroup, select, textarea {
        direction: rtl;
    }
    .stButton button {
        direction: rtl;
    }
    header {
        direction: rtl;
    }
    .main {
        direction: rtl;
    }
    .stChatInput {
        direction: rtl;
    }
</style>
"""
st.markdown(_RTL_STYLESHEET, unsafe_allow_html=True)

@st.cache_resource
def init_models(
    *,
    embedding_model_name="intfloat/multilingual-e5-large",
    persist_directory="db/store_500",
    llm_model="gemma2:27b",
    top_k=1,
    num_predict=256,
):
    """Build and cache the HyDE retrieval chain and the final answer chain.

    Wrapped in ``st.cache_resource`` so the embedding model, vector store and
    LLM client are created once and reused across Streamlit reruns instead of
    being rebuilt on every interaction. The keyword arguments default to the
    values this app was written with. They are part of the cache key, so
    passing different values builds (and caches) a separate set of resources.

    Args:
        embedding_model_name: Embedding model passed to ``init_embedding``
            (run on CPU).
        persist_directory: Directory of the persisted vector store.
        llm_model: Ollama model tag used for both chains.
        top_k: Forwarded to ``get_vector_store`` as ``k_arg``; presumably
            the number of documents retrieved per query (confirm in
            ``utils.vector_store``).
        num_predict: Maximum number of tokens the LLM may generate.

    Returns:
        A ``(retrieval_chain, final_rag_chain)`` tuple:
        - ``retrieval_chain`` takes ``{"question": ...}``. The LLM writes
          text from ``hyde_prompt``, and that text is piped into the
          retriever; the result is whatever the retriever returns.
        - ``final_rag_chain`` takes ``{"context": ..., "question": ...}``
          and returns the LLM answer as a plain string.
    """
    embedding_model, _ = init_embedding(
        model_name=embedding_model_name,
        model_kwargs={"device": "cpu", "trust_remote_code": True},
        test_query="Embedding Model Initialized Successfully.",
    )

    retriever = get_vector_store(
        persist_directory=persist_directory,
        embedding_model=embedding_model,
        k_arg=top_k,
    )

    # temperature=0.0 keeps answers as deterministic as the model allows,
    # which matters for a customer-support bot.
    llm = ChatOllama(
        model=llm_model,
        temperature=0.0,
        verbose=True,
        num_predict=num_predict,
    )

    # HyDE step: the LLM writes text for the question, and that text (not
    # the raw question) is used as the retrieval query.
    generate_docs_for_retrieval = hyde_prompt() | llm | StrOutputParser()
    retrieval_chain = generate_docs_for_retrieval | retriever

    # Answer step: the LLM answers the question grounded on the retrieved context.
    final_rag_chain = hyde_rag_prompt() | llm | StrOutputParser()

    return retrieval_chain, final_rag_chain

def _format_docs(docs):
    """Join retrieved documents into a single plain-text context string.

    Without this, the retriever's list would be stringified as its Python
    repr when substituted into the prompt (assuming ``hyde_rag_prompt`` is a
    standard LangChain template). That puts escaped newlines and document
    metadata into the LLM context. Items lacking ``page_content`` fall back
    to ``str(item)``; a plain string is passed through unchanged.
    """
    if isinstance(docs, str):
        return docs
    return "\n\n".join(
        doc.page_content if hasattr(doc, "page_content") else str(doc)
        for doc in docs
    )


retrieval_chain, final_rag_chain = init_models()


st.title("چت بات هوشمند پاسخگویی به سوالات مشتریان")

# Streamlit re-executes this script on every interaction, so the chat log is
# kept in session_state and replayed on each run.
if "history" not in st.session_state:
    st.session_state.history = []

for msg in st.session_state.history:
    st.chat_message(msg["role"]).markdown(msg["content"])

prompt = st.chat_input("😇 سوال خود را از دستیار هوشمند ما بپرسید")

if prompt:
    st.chat_message("user").markdown(prompt)
    st.session_state.history.append({"role": "user", "content": prompt})

    with st.chat_message("assistant"):
        try:
            # Both calls hit the LLM and can be slow, so show progress.
            with st.spinner("در حال پردازش سوال شما..."):
                retrieved_docs = retrieval_chain.invoke({"question": prompt})
                response = final_rag_chain.invoke(
                    {"context": _format_docs(retrieved_docs), "question": prompt}
                )
        except Exception:
            # UI boundary: an unreachable Ollama server or a retrieval failure
            # should show a message, not a raw traceback. Full details go to
            # the server log.
            logging.getLogger(__name__).exception("Failed to answer user question")
            st.error("متأسفانه هنگام پاسخ به سوال شما خطایی رخ داد. لطفاً دوباره تلاش کنید.")
        else:
            st.markdown(response)
            # Only record the assistant turn when an answer was actually produced.
            st.session_state.history.append({"role": "assistant", "content": response})

    # NOTE(review): kept from the original; likely unnecessary, since CPython
    # frees most objects by reference counting.
    gc.collect()
