financial-assistant

Форк
0
92 строки · 3.9 Кб
1
import csv, numpy as np
2
from sentence_transformers import SentenceTransformer
3
from ..models import Bank, LoanDetailedDescription, ProductCategories
4
from .preprocessing import preprocessing
5

6
model_id = 'intfloat/multilingual-e5-base'
7
model = SentenceTransformer(model_id) 
8
prepocessed_db_texts = []
9
prepocessed_db_texts_idx = []
10

11
def load_db_texts_from_csv(bank_id, product_id, filename='db_texts.csv'):
12
  db_texts = []
13
  db_texts_idx = []
14
  try:
15
    with open(filename, 'r', encoding='utf-8') as csvfile:
16
      reader = csv.DictReader(csvfile)
17
      if bank_id is not None and product_id is not None:
18
        for idx, row in enumerate(reader):
19
          if(int(row['bank_id']) == bank_id and int(row['product_id']) == product_id):
20
            db_texts.append(row['description'])
21
            db_texts_idx.append(idx)
22
            
23
      elif bank_id is not None and product_id is None:
24
        for idx, row in enumerate(reader):
25
          if(int(row['bank_id']) == bank_id):
26
            db_texts.append(row['description'])
27
            db_texts_idx.append(idx)
28
            
29
      elif bank_id is None and product_id is not None:
30
        for idx, row in enumerate(reader):
31
          if(int(row['product_id']) == product_id):
32
            db_texts.append(row['description'])
33
            db_texts_idx.append(idx)
34
      else:
35
         for row in reader:
36
            db_texts.append(row['description'])
37
            db_texts_idx.append(idx)
38
        # db_texts.append(f"{row['bank_id']}|{row['product_id']}|{row['category_id']}|{row['preprocessed']}|{row['description']}")
39
  except FileNotFoundError:
40
    pass
41
  return db_texts, db_texts_idx
42

43
def save_db_texts_to_csv(db_texts	, filename='db_texts.csv'):
44
  with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
45
    fieldnames = ['bank_id', 'product_id', 'category_id', 'preprocessed', 'description']
46
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
47
    writer.writeheader()
48
    for text in db_texts:
49
      bank_id, product_id, category_id, preprocessed, description = text.split('||')
50
      writer.writerow({'bank_id': bank_id, 'product_id': product_id, 'category_id': category_id, 'preprocessed': preprocessed, 'description': description})
51

52
def get_db_texts(bank_id, product_id):
53
  global prepocessed_db_texts, prepocessed_db_texts_idx
54
  
55
  if not prepocessed_db_texts:
56
    prepocessed_db_texts, prepocessed_db_texts_idx = load_db_texts_from_csv(bank_id, product_id)
57
    
58
    if not prepocessed_db_texts:
59
      loan_descriptions = LoanDetailedDescription.objects.order_by('id').all()
60
      
61
      for desc in loan_descriptions:
62
         bank = Bank.objects.get(id=desc.bank_id_id).nameRus.lower()
63
         category = ProductCategories.objects.get(id=desc.category_id_id).categoryNameRus.lower()
64
         title = desc.title.lower()  
65
         description = desc.description
66
         link = desc.link
67
         preprocessed = preprocessing(f"{description}")
68

69
         prepocessed_db_texts.append(f"{desc.bank_id_id}||{desc.product_id_id}||{desc.category_id_id}||{preprocessed}||{bank}|{category}|{title}|{description}|{link}") 
70
         
71
      save_db_texts_to_csv(prepocessed_db_texts)
72
      prepocessed_db_texts, prepocessed_db_texts_idx = load_db_texts_from_csv(bank_id, product_id)
73
  return prepocessed_db_texts, prepocessed_db_texts_idx
74

75
def load_dense_vectors(indexies):
76
    try:
77
        vecs_texts = np.load('dense_vectors.npy')
78
    except FileNotFoundError:
79
        preprocessed_texts = []
80
        with open('db_texts.csv', 'r', encoding='utf-8') as csvfile:
81
            reader = csv.DictReader(csvfile)
82
            for row in reader:
83
                preprocessed_texts.append(row['preprocessed'])
84

85
        vecs_texts = model.encode(preprocessed_texts, normalize_embeddings=True)
86
        np.save('dense_vectors.npy', vecs_texts) 
87
    
88
    return vecs_texts[indexies]
89

90
def load_texts_by_indices(db_texts, indices):
91
    topK_texts = [db_texts[int(idx)] for idx in indices if int(idx) < len(db_texts)]
92
    return topK_texts 

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.