import gc
import hashlib
import random
import uuid
from pathlib import Path

import streamlit as st
import whisper
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_ollama.chat_models import ChatOllama

from prompts.rag import improve_rag_prompt
from utils.embedding_model import init_embedding
from utils.vector_store import get_vector_store

# Right-to-left layout overrides so Persian text, and the widgets that hold it,
# flow and align from the right. The CSS is injected as raw HTML, which is why
# unsafe_allow_html must be enabled on the markdown call below.
_RTL_CSS = """
<style>
    html, body, [class*="css"] {
        direction: rtl;
        text-align: right;
    }
    button, input, optgroup, select, textarea {
        direction: rtl;
    }
    .stButton button {
        direction: rtl;
    }
    header {
        direction: rtl;
    }
    .main {
        direction: rtl;
    }
    .stChatInput {
        direction: rtl;
    }
</style>
"""
st.markdown(_RTL_CSS, unsafe_allow_html=True)

@st.cache_resource
def init_models():
    """Load every heavy model once and return the RAG chain plus the ASR model.

    Streamlit re-executes this script on each interaction; ``st.cache_resource``
    keeps the objects built here alive across reruns, so the embedding model,
    vector store, Ollama client and Whisper model are only loaded once.

    Returns:
        tuple: ``(rag_chain, asr_model)``. ``rag_chain`` takes a question string
        and (through ``StrOutputParser``) yields the answer as a string;
        ``asr_model`` is the Whisper model returned by ``whisper.load_model``.
    """
    # Multilingual E5 embeddings on CPU. The second return value is unused here;
    # the test_query is presumably a start-up smoke test (confirm in
    # utils.embedding_model).
    embedder, _ = init_embedding(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": "cpu", "trust_remote_code": True},
        test_query="Embedding Model Initialized Successfully.",
    )

    # Retriever over the persisted store; k_arg=1 is forwarded to the project
    # helper (presumably "top-1 document" — verify in utils.vector_store).
    doc_retriever = get_vector_store(
        persist_directory="db/store_500",
        embedding_model=embedder,
        k_arg=1,
    )

    # Temperature 0 minimises sampling randomness; num_predict bounds reply length.
    chat_model = ChatOllama(
        model="gemma2:27b",
        temperature=0.0,
        verbose=True,
        num_predict=256,
    )

    # The retriever's output fills {context}; the raw question passes through
    # unchanged into {question}.
    chain_inputs = {"context": doc_retriever, "question": RunnablePassthrough()}
    answer_chain = chain_inputs | improve_rag_prompt() | chat_model | StrOutputParser()

    speech_model = whisper.load_model("large-v3-turbo")

    return answer_chain, speech_model

rag_chain, asr_model = init_models()


def _answer(question):
    """Render *question* as a user turn, answer it with the RAG chain, render the reply.

    Both turns are appended to ``st.session_state.history`` so they are redrawn
    by the history loop on every subsequent rerun.
    """
    st.chat_message("user").markdown(question)
    st.session_state.history.append({"role": "user", "content": question})

    response = rag_chain.invoke(question)

    st.chat_message("assistant").markdown(response)
    st.session_state.history.append({"role": "assistant", "content": response})

    gc.collect()


st.title("چت بات هوشمند پاسخگویی به سوالات مشتریان")

if "history" not in st.session_state:
    st.session_state.history = []

for msg in st.session_state.history:
    st.chat_message(msg["role"]).markdown(msg["content"])


audio = st.audio_input("🤩 با صدای خودتان سوال بپرسید")
prompt = st.chat_input("😇 سوال خود را از دستیار هوشمند ما بپرسید")

if prompt:
    _answer(prompt)

if audio:
    audio_bytes = audio.getvalue()
    audio_digest = hashlib.sha256(audio_bytes).hexdigest()

    # The audio widget keeps returning the same recording on every rerun (for
    # example after a typed question), so only process a recording once.
    # The digest is recorded before processing so a failing recording is not
    # retried on every later interaction.
    if st.session_state.get("last_audio_digest") != audio_digest:
        st.session_state.last_audio_digest = audio_digest

        with st.spinner("... پردازش صوت"):
            audio_dir = Path("audio")
            audio_dir.mkdir(parents=True, exist_ok=True)
            # uuid4 avoids the name collisions (and silent overwrites) that
            # random.randint(1, 100000) produced after a few hundred samples.
            # NOTE(review): recordings are kept on disk as before; add cleanup
            # if they are not needed for later inspection.
            sample_path = audio_dir / f"{uuid.uuid4().hex}_test.wav"
            sample_path.write_bytes(audio_bytes)

            result = asr_model.transcribe(str(sample_path))

        question = result["text"].strip()
        if question:
            # TODO: error handling if the transcription is not Persian.
            _answer(question)
        else:
            # Nothing intelligible was transcribed; don't query the RAG chain
            # with an empty question. ("No speech detected, please try again.")
            st.warning("صدایی تشخیص داده نشد، لطفاً دوباره تلاش کنید.")
