김이언
김이언
🏅 AI 마스터
🌿 뉴비 파트너
🌈 지피터스금손

Langchain으로 데이터 시각화하기(+Dash)

Langchain을 기반으로 Dash 프레임워크와 Plotly 라이브러리를 이용하여 데이터 시각화 웹앱을 만들어 보았습니다. 😄

목차

  • 시작하기

  • 데이터 준비

  • 프로그램 흐름과 주요 기능

  • 프로그램 코딩

  • 실행 결과

  • 정리



1. 시작하기

랭체인 기반의 데이터 분석 웹앱을 만들자는 목표를 갖고 시작했습니다.😁
주어진 데이터를 분석하기 위해 사용자의 요구에 맞추어 그래프를 그리고, 그래프를 설명하는 기능을 구현했습니다. 사용자가 분석 결과를 보고 이어서 다음 요청을 하면 이전의 대화를 기억하여 후속 작업 이루어지게 하였습니다.


2. 데이터 준비

Kaggle의 Students Performance in Exams 데이터를 사용했습니다.
StudentsPerformance.csv 파일은 학생들의 시험 성적 데이터로, 1000개의 레코드와 다음 8개 컬럼으로 구성되어 있습니다. 시험 준비, 인종, 부모의 배경 등이 학생의 성적에 미치는 영향 등을 분석할 수 있습니다.

  • gender: 성별

  • race/ethnicity: 인종/민족

  • parental level of education: 부모 학력 수준

  • lunch: 점심식사 여부

  • test preparation course: 시험 준비학습 여부

  • math score: 수학 성적

  • reading score: 읽기 성적

  • writing score: 쓰기 성적


3. 프로그램 흐름과 주요 기능

주요 기능

  • 데이터: csv 파일을 데이터 프레임으로 만들고, 프롬프트에서 사용하기 위해 문자열로 변환한다.

  • 메모리: ConversationSummaryBufferMemory를 설정하여 이전 대화를 200 토큰까지 기록한다.

  • 웹브라우저 실행: 상단에 데이터 그리드가 보이고 사용자 입력창이 표시된다.

  • 입력 처리: 사용자가 요청을 입력하면 메모리가 추가된 체인을 실행하여 모델의 응답을 반환한다.

  • 출력 처리: 위 결과에서 코드를 추출하여 plotly를 이용한 그래프를 생성하고, 그래프 설명을 작성한다. Dash 앱에서 설정한 Layout에 맞추어 웹 페이지를 구성한다.


4. 프로그램 코딩

  • 환경변수와 LLM 모델을 설정합니다.

  • CSV 파일에서 데이터(테스트로 상위 100개 사용)를 가져옵니다.

import os 
from langchain_openai import ChatOpenAI 
from langchain_core.prompts import ChatPromptTemplate  
from langchain.prompts.chat import MessagesPlaceholder 
from langchain.memory import ConversationSummaryBufferMemory 
from langchain.schema.runnable import RunnablePassthrough  

from dash import Dash, html, dcc, callback, Output, Input, State 
import dash_ag_grid as dag 
import pandas as pd  
import re  # 정규 표현식 사용

# 환경 변수 설정

# CSV 파일에서 데이터 로드 및 처리
df = pd.read_csv('StudentsPerformance.csv')  # CSV 파일 읽기
df_select = df.head(100)  # 상위 100개 행 선택
csv_string = df_select.to_string(index=False)  # 프롬프트에서 사용하기 위해 데이터프레임을 문자열로 변환

# ChatGPT 모델 설정
llm = ChatOpenAI(model="gpt-4o", temperature=0)  
  • 대화 기록을 위한 메모리를 설정하고, 메모리 로드 함수를 정의합니다.

# 대화 요약 메모리 설정
memory = ConversationSummaryBufferMemory(
   llm=llm, # 대화 요약에 사용할 모델 지정
   max_token_limit=200, # 메모리에 저장할 최대 토큰 수, 초과시 이전 것부터 삭제됨 
   memory_key="chat_history", # 대화 기록을 저장할 키
   return_messages=True # 메모리에서 전체 메시지를 반환하도록 설정
)

# 메모리 로드 함수 정의
def load_memory(input):
   return memory.load_memory_variables({})["chat_history"] #  메모리에서 대화 기록 변수를 불러와 반환
  • 프롬프트 템플릿을 정의, 그래프 라이브러리로 Plotly를 사용하고 출력 형태를 따르도록 지시합니다.

  • 메모리가 추가된 체인을 설정하고, 체인 실행 함수를 정의합니다.

# 프롬프트 템플릿 정의
prompt = ChatPromptTemplate.from_messages([
   # 시스템 메시지, 데이터 시각화 전문가 역할
   (
       "system",
       "You're a data visualization expert using Plotly, a graphing library. "
       "The data to analyze is provided as a StudentsPerformance.csv file. Here are the selected data set: {csv_string}"
       "Follow the user's indications when creating the graph. Do not include fig.show() in your code. "
       "After providing the code for your graph, please write a structured and detailed description of your graph in Korean. "
       "Start your description with '### 그래프 설명' on a new line after the code block. "
       "Organize your explanation into the following categories, each starting with a subtitle in bold:\n"
       "**1. 주요 통계:**\n"
       "**2. 데이터 분포:**\n"
       "**3. 주목할 만한 점:**\n"
       "**4. 결론 및 인사이트:**\n"
       "Under each category, present information as bullet points, with each point starting with '• ' and followed by a line break. "
       "Limit your explanation to 2-3 bullet points per category, prioritizing the most important information. "
       "Include specific numeric data in your explanation, such as averages, maximums, minimums, or other relevant statistics. "
       "In the conclusion, directly address the user's specific question or request about the data. "
       "Provide insights that are relevant to the user's query and the visualized data. "
       "Ensure that your description accurately reflects the data shown in the generated graph and answers the user's question. "
   ),
   MessagesPlaceholder(variable_name="chat_history"),  # 대화 기록 플레이스홀더
   ("human", "{question}"),  # 사용자 질문 플레이스홀더
])

# RunnablePassthrough를 사용해 체인 설정, 대화 기록 불러오기
chain = RunnablePassthrough.assign(chat_history=load_memory) | prompt | llm

# 체인 실행 함수 정의
def invoke_chain(question):
   result = chain.invoke({"question": question, "csv_string": csv_string}) # question, csv_string를 체인에 전달, 결과 반환
   memory.save_context( # 메모리에 입력된 질문과 체인 결과 저장
       {"input": question},
       {"output": result.content},
   )
   return result.content
  • Dash 앱을 초기화하고, 화면 레이아웃을 설정합니다.

# Dash 초기화 및 레이아웃 설정
app = Dash()  
app.layout = html.Div([
   # 헤더 설정
   html.H2([
       html.Img(src=app.get_asset_url('chart.gif'), style={'height': '30px', 'marginRight': '10px', 'verticalAlign': 'middle'}),
       "AI 기반 데이터 시각화 도구"
   ], style={'marginBottom': '10px'}),    

   # 데이터 그리드 표시
   html.Div([
       dag.AgGrid(
           rowData=df_select.to_dict("records"),
           columnDefs=[{"field": i} for i in df_select.columns],
           defaultColDef={"filter": True, "sortable": True, "floatingFilter": True},
           style={'height': '30vh', 'width': '100%'}
       )
   ], style={'marginBottom': '20px'}),

  # 사용자 입력 영역
  html.Strong("데이터에서 무엇을 보고 싶으신가요?"),
  dcc.Textarea(id='user-request', style={'width': '100%', 'height': '30px', 'margin-top': '5px'}),
  html.Br(),
  html.Button('Submit', id='my-button'),
  
  # 결과 표시 영역
  dcc.Loading(
      [
           html.Div(id='my-figure', style={'marginTop': '20px', 'marginBottom': '20px'}),
           dcc.Markdown(id='content', children='', style={'whiteSpace': 'pre-wrap', 'wordBreak': 'break-word'})
      ],
      type='default'
  )  
], style={ # 여백을 위한 CSS 설정
   'margin-left': 'auto', 
   'margin-right': 'auto',
   'max-width': '1600px',
   'padding': '0 20px',  
   'paddingBottom': '30px'
})
  • 그래프 객체 추출/그래프 그리기/콜백 함수를 정의합니다.

  • 정규 표현식으로 코드를 추출하여 그래프를 그리고, 그래프 설명을 출력합니다.

# 코드에서 그래프 객체 추출 함수 정의
def get_fig_from_code(code):
   local_variables = {}
   exec(code, globals(), local_variables)
   return local_variables['fig']

# 콜백 함수 정의
@callback(
   Output('my-figure', 'children'), # 콜백 함수가 반환할 첫 번째 출력       
   Output('content', 'children'), # 두 번째 출력
   Input('my-button', 'n_clicks'), # submit 버튼을 누르면 콜백 함수가 반응
   State('user-request', 'value'), # 콜백 함수는 사용자가 입력한 값을 사용
   prevent_initial_call=True # 초기 페이지 로드 시 콜백 함수 실행 방지
)
def create_graph(_, user_input):
   result_output = invoke_chain(user_input) # 체인 실행 결과 반환
   # print(result_output)  # 디버깅용 출력

   # 코드 블록 추출 및 처리
   code_block_match = re.search(r'```(?:python)?(.*?)```', result_output, re.DOTALL | re.IGNORECASE)
   if code_block_match:
       code_block = code_block_match.group(1).strip() # 코드 블록을 변수에 저장
       cleaned_code = re.sub(r'(?m)^\s+fig\.show\(\)\s+$', '', code_block) # 코드 블록에서 불필요한 부분을 제거
       explanation = re.sub(r'```(?:python)?.*?```', '', result_output, flags=re.DOTALL | re.IGNORECASE).strip() # 설명 부분을 추출

       fig = get_fig_from_code(cleaned_code) # 추출한 코드 블록을 실행하여 그래프 객체 생성
       return_value = (dcc.Graph(figure=fig), explanation) # 그래프와 설명을 반환값으로 설정
       return return_value
   
   else:
       return "", "그래프를 생성할 수 없습니다."
   
# 메인 실행 부분
if __name__ == '__main__':
  app.run_server(debug=False, port=8887)  # 디버그 모드 비활성화, 포트 8887에서 서버 실행

5. 실행결과

  • “쓰기 점수와 읽기 점수의 관계를 보여줘”를 입력하고 실행한 결과

  • 이어서 “성별 변수를 추가해서 다시 그려줘”라고 요청한 결과


6. 정리

  • 간단한 구성이라고 생각했는데 시간이 많이 걸렸습니다.🤣 당연하게도 코드와 그 내용을 이해하는 것이 가장 중요하다고 다시 한번 느꼈어요!

  • 단순한 앱이지만 랭체인을 이용하여 작동되는 무언가를 만드는 것이 의미있었습니다. AI를 모르는, 더구나 랭체인은 들어본 적도 없는 주변인들이 웹앱에는 반응을 하며, 이거 저거가 되느냐고 질문합니다.(만들어 줄 수 없는 것이 함정🥶)

  • Dash 제공 기능을 활용하여 데이터를 앱에서 직접 업로드 하도록 변경하고, 화면 구성과 분석 기능을 추가하면 보다 발전된 앱이 되리라 생각합니다~


#11기랭체인

14
6개의 답글

👉 이 게시글도 읽어보세요