Chatbot

Joined
Sep 21, 2024
Messages
2
Reaction score
0
In this C++ code, how can I fall back to searching the database when the myfile.txt file doesn't contain the answer?
C++:
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <algorithm>
#include <numeric>
#include <cmath>
#include <random>
#include <deque>
#include <tuple>

// Typedef pour les vecteurs
using Vector = std::vector<double>;
using WordEmbeddings = std::unordered_map<std::string, Vector>;

// Fixed-capacity buffer of past experiences for experience replay.
// Each entry is a (state, action, reward, next state) tuple, oldest first.
class ReplayMemory {
public:
    ReplayMemory(size_t capacity) : m_capacity(capacity) {}

    // Store one experience, evicting the oldest entry once capacity is reached.
    void add(const std::tuple<int, int, double, int>& experience) {
        while (m_memory.size() >= m_capacity) {
            m_memory.pop_front();
        }
        m_memory.push_back(experience);
    }

    // Draw up to batch_size experiences uniformly at random, without
    // replacement (fewer are returned when the buffer holds fewer).
    std::vector<std::tuple<int, int, double, int>> sample(size_t batch_size) {
        std::vector<std::tuple<int, int, double, int>> drawn;
        std::mt19937 rng{std::random_device{}()};
        std::sample(m_memory.begin(), m_memory.end(),
                    std::back_inserter(drawn), batch_size, rng);
        return drawn;
    }

    // True when nothing has been stored yet.
    bool isEmpty() const { return m_memory.empty(); }

private:
    size_t m_capacity;
    // Oldest-first queue of (state, action, reward, next state) tuples.
    std::deque<std::tuple<int, int, double, int>> m_memory;
};

// Tabular Q-learning agent backed by a replay memory.
//
// Update rule: Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a'))
// States are indices in [0, numStates), actions in [0, numActions).
class QLearning {
public:
    // numStates/numActions size the Q-table; gamma is the discount factor,
    // alpha the learning rate, replayMemorySize the replay-buffer capacity.
    QLearning(int numStates, int numActions, double gamma, double alpha, size_t replayMemorySize) :
        m_numStates(numStates), m_numActions(numActions), m_gamma(gamma), m_alpha(alpha),
        m_replayMemory(replayMemorySize), m_rng(std::random_device{}()) {
        m_qTable.resize(m_numStates, std::vector<double>(m_numActions, 0.0));
    }

    // Record one (state, action, reward, nextState) transition.
    void addExperience(int state, int action, double reward, int nextState) {
        m_replayMemory.add(std::make_tuple(state, action, reward, nextState));
    }

    // Sample up to batch_size stored transitions and apply the Q update to each.
    void learnFromReplay(size_t batch_size) {
        if (m_replayMemory.isEmpty()) return;

        for (const auto& experience : m_replayMemory.sample(batch_size)) {
            const auto& [state, action, reward, nextState] = experience;
            double maxNextQValue = *std::max_element(m_qTable[nextState].begin(), m_qTable[nextState].end());
            m_qTable[state][action] = (1 - m_alpha) * m_qTable[state][action]
                                    + m_alpha * (reward + m_gamma * maxNextQValue);
        }
    }

    // Epsilon-greedy action selection: explore with probability epsilon,
    // otherwise return the action with the highest Q-value for `state`.
    // Uses a properly seeded mt19937 instead of the original unseeded
    // std::rand(), which also suffered from modulo bias.
    int chooseAction(int state, double epsilon) {
        std::uniform_real_distribution<double> uniform01(0.0, 1.0);
        if (uniform01(m_rng) < epsilon) {
            std::uniform_int_distribution<int> anyAction(0, m_numActions - 1);
            return anyAction(m_rng); // Exploration
        }
        const auto& row = m_qTable[state];
        return static_cast<int>(std::distance(row.begin(),
                                              std::max_element(row.begin(), row.end()))); // Exploitation
    }

    // Print the whole Q-table, one row per state.  (Now const.)
    void getQValue() const {
        std::cout << "Final Q-Value Table:" << std::endl;
        for (int state = 0; state < m_numStates; ++state) {
            std::cout << "State " << state << ": ";
            for (int action = 0; action < m_numActions; ++action) {
                std::cout << m_qTable[state][action] << " ";
            }
            std::cout << std::endl;
        }
    }

    // Save the Q-values to a file, one state per line.  (Now const; the
    // stream closes itself on scope exit.)
    void saveQValues(const std::string& filename) const {
        std::ofstream outFile(filename);
        if (outFile.is_open()) {
            for (const auto& state : m_qTable) {
                for (const auto& qValue : state) {
                    outFile << qValue << " ";
                }
                outFile << std::endl;
            }
        } else {
            std::cerr << "Error opening file for saving Q-values." << std::endl;
        }
    }

    // Load Q-values previously written by saveQValues(); on open failure
    // the table keeps its current contents and an error is reported.
    void loadQValues(const std::string& filename) {
        std::ifstream inFile(filename);
        if (inFile.is_open()) {
            for (auto& state : m_qTable) {
                for (auto& qValue : state) {
                    inFile >> qValue;
                }
            }
        } else {
            std::cerr << "Error opening file for loading Q-values." << std::endl;
        }
    }

private:
    int m_numStates;
    int m_numActions;
    double m_gamma; // discount factor
    double m_alpha; // learning rate
    std::vector<std::vector<double>> m_qTable; // Q[state][action]
    ReplayMemory m_replayMemory; // replay buffer of past transitions
    std::mt19937 m_rng; // seeded PRNG for epsilon-greedy choices
};

// One variable in the Bayesian network: its name, the directed-graph
// links to parent/child variables, and a value -> probability table.
class Node {
public:
    std::string name;                                    // variable identifier
    std::vector<Node*> parents;                          // incoming edges (non-owning)
    std::vector<Node*> children;                         // outgoing edges (non-owning)
    std::unordered_map<std::string, double> distribution; // value -> probability

    Node(const std::string& varName) : name(varName) {}
};

// Classe pour représenter le réseau bayésien
class BayesianNetwork {
public:
    std::vector<Node*> nodes;

    ~BayesianNetwork() {
        for (auto node : nodes) {
            delete node;
        }
    }

    void addNode(Node* node) {
        nodes.push_back(node);
    }

    void addEdge(const std::string& parent, const std::string& child) {
        Node* parentNode = getNode(parent);
        Node* childNode = getNode(child);
        if (parentNode && childNode) {
            parentNode->children.push_back(childNode);
            childNode->parents.push_back(parentNode);
        }
    }

    Node* getNode(const std::string& name) {
        for (auto& node : nodes) {
            if (node->name == name) {
                return node;
            }
        }
        return nullptr;
    }

    double getConditionalProbability(const std::string& target, const std::unordered_map<std::string, std::string>& evidence) {
        Node* targetNode = getNode(target);
        if (!targetNode) return 0.0;

        double jointProbability = calculateJointProbability(targetNode, evidence);
        return jointProbability; // Normaliser si nécessaire
    }

private:
    double calculateJointProbability(Node* targetNode, const std::unordered_map<std::string, std::string>& evidence) {
        double jointProbability = targetNode->distribution[targetNode->name];

        for (const auto& [var, value] : evidence) {
            Node* evidenceNode = getNode(var);
            if (evidenceNode) {
                jointProbability *= evidenceNode->distribution.at(value);
            }
        }
        return jointProbability;
    }
};

// Tokenize a string on any whitespace; empty input yields an empty vector.
std::vector<std::string> split(const std::string& s) {
    std::vector<std::string> words;
    std::istringstream stream(s);
    for (std::string word; stream >> word; ) {
        words.push_back(word);
    }
    return words;
}

// Drop every token that appears in stop_words and re-join the remaining
// tokens with single spaces (no trailing space, "" when nothing survives).
std::string remove_stop_words(const std::string& text, const std::unordered_set<std::string>& stop_words) {
    std::string filtered;
    for (const auto& word : split(text)) {
        if (stop_words.count(word) == 0) {
            if (!filtered.empty()) {
                filtered += ' ';
            }
            filtered += word;
        }
    }
    return filtered;
}

// Draw a random embedding: component i is distributed as
// Normal(alpha[i], beta_param[i]).
//
// NOTE(review): `word` does not influence the result (no per-word seeding),
// so repeated calls for the same word give different vectors; callers cache
// the first result (see getWordEmbeddings).
//
// Fix: the original indexed beta_param[i] with i bounded only by
// alpha.size(), an out-of-bounds read whenever beta_param was shorter.
// Missing beta components are now treated as 0 (no noise).
Vector generateEmbedding(const std::string& word, const Vector& alpha, const Vector& beta_param) {
    (void)word; // kept for interface symmetry; see note above
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<double> distribution(0.0, 1.0);

    Vector embedding(alpha.size());
    for (size_t i = 0; i < embedding.size(); i++) {
        const double scale = (i < beta_param.size()) ? beta_param[i] : 0.0;
        embedding[i] = alpha[i] + scale * distribution(gen);
    }

    return embedding;
}

// Build one embedding per unique word found in the dialogue responses.
// Each word is embedded at most once; later occurrences reuse the cached
// vector (important because generateEmbedding is randomized).
WordEmbeddings getWordEmbeddings(const std::unordered_map<std::string, std::unordered_map<std::string, int>>& dialogues, const Vector& alpha, const Vector& beta_param) {
    WordEmbeddings embeddings;
    for (const auto& [prompt, responses] : dialogues) {
        (void)prompt;
        for (const auto& [response, count] : responses) {
            (void)count;
            for (const auto& word : split(response)) {
                if (embeddings.count(word) == 0) {
                    embeddings[word] = generateEmbedding(word, alpha, beta_param);
                }
            }
        }
    }
    return embeddings;
}

// Fonction pour calculer la similarité cosinus
double computeCosineSimilarity(const Vector& v1, const Vector& v2) {
    double dotProduct = 0.0;
    double norm1 = 0.0;
    double norm2 = 0.0;

    for (size_t i = 0; i < v1.size(); i++) {
        dotProduct += v1[i] * v2[i];
        norm1 += v1[i] * v1[i];
        norm2 += v2[i] * v2[i];
    }

    norm1 = std::sqrt(norm1);
    norm2 = std::sqrt(norm2);

    if (norm1 == 0.0 || norm2 == 0.0) {
        return 0.0;
    }

    return dotProduct / (norm1 * norm2);
}

// Average pairwise cosine similarity over every (question word, response
// word) pair for which both words have an embedding; 0.0 when no pair does.
double computeSimilarity(const std::string& question, const std::string& response, const WordEmbeddings& wordEmbeddings) {
    const std::vector<std::string> qWords = split(question);
    const std::vector<std::string> rWords = split(response);

    double total = 0.0;
    int pairCount = 0;

    for (const auto& qWord : qWords) {
        auto qIt = wordEmbeddings.find(qWord);
        if (qIt == wordEmbeddings.end()) continue; // no embedding: skip all pairs
        for (const auto& rWord : rWords) {
            auto rIt = wordEmbeddings.find(rWord);
            if (rIt == wordEmbeddings.end()) continue;
            total += computeCosineSimilarity(qIt->second, rIt->second);
            ++pairCount;
        }
    }

    return pairCount > 0 ? total / pairCount : 0.0;
}

// Return the substring of s strictly between the first occurrence of
// start_delim and the next occurrence of stop_delim after it; "" when
// either delimiter is absent.
//
// Fix: positions are kept as std::string::size_type.  The original stored
// them in `unsigned`, which truncates std::string::npos on LP64 platforms,
// so the "not found" checks never fired and substr() could throw.
std::string get_str_between_two_str(const std::string &s,
                                     const std::string &start_delim,
                                     const std::string &stop_delim) {
    const std::string::size_type first_delim_pos = s.find(start_delim);
    if (first_delim_pos == std::string::npos) return ""; // start delimiter not found
    const std::string::size_type end_pos_of_first_delim = first_delim_pos + start_delim.length();
    const std::string::size_type last_delim_pos = s.find(stop_delim, end_pos_of_first_delim);
    if (last_delim_pos == std::string::npos) return ""; // stop delimiter not found

    return s.substr(end_pos_of_first_delim, last_delim_pos - end_pos_of_first_delim);
}
// Scan file_name for sentences (terminated by '.', '!' or '?') containing
// `token`, score each with a TF-IDF-style weight, and return the
// best-scoring sentence with its score.  Returns {"", 0.0} when the file
// cannot be opened or no sentence contains the token.
// NOTE(review): avg_doc_len is accepted but unused, as in the original.
std::pair<std::string, double> compute_tfidf(const std::string& token, const std::string& file_name, int word_count, double avg_doc_len, double doc_len_correction) {
    std::ifstream input(file_name);
    if (!input) {
        std::cerr << "Error opening file: " << file_name << std::endl;
        return {"", 0.0};
    }

    std::unordered_map<std::string, double> scored_phrases;
    std::string line;
    while (std::getline(input, line)) {
        // Walk the line sentence by sentence.
        std::string::size_type start = 0;
        for (;;) {
            const std::string::size_type end = line.find_first_of(".!?", start);
            if (end == std::string::npos) break; // no further sentence terminator

            std::string phrase = line.substr(start, end - start + 1); // keep the terminator
            if (phrase.find(token) != std::string::npos) {
                const std::vector<std::string> words = split(phrase);
                const int n = static_cast<int>(words.size());
                const double tf = std::count(words.begin(), words.end(), token) / static_cast<double>(word_count);
                const double idf = std::log((n - tf + 0.5) / (tf + 0.5));
                scored_phrases[phrase] = tf / (tf + doc_len_correction) * idf;
            }
            start = end + 1; // resume after this sentence
        }
    }

    if (scored_phrases.empty()) return {"", 0.0};

    const auto best = std::max_element(scored_phrases.begin(), scored_phrases.end(),
                                       [](const auto& a, const auto& b) { return a.second < b.second; });
    return {best->first, best->second};
}

// Build a network with one node per variable and a fixed example topology
// (A -> B, A -> C); edges are only created when both endpoints exist.
BayesianNetwork structureLearning(const std::vector<std::string>& variables) {
    BayesianNetwork network;

    for (const auto& name : variables) {
        network.addNode(new Node(name));
    }

    // Hard-coded example structure.
    network.addEdge("A", "B");
    network.addEdge("A", "C");

    return network;
}

// Estimate each node's marginal probability as its relative frequency in
// `data`, stored under the node's own name in its distribution table.
//
// Fix: guards against empty data — the original divided by data.size()
// unconditionally (0/0 -> NaN written into every distribution).
void parameterLearning(BayesianNetwork& BN, const std::vector<std::string>& data) {
    if (data.empty()) {
        return; // no observations: leave the distributions untouched
    }
    for (auto& node : BN.nodes) {
        const double count = static_cast<double>(std::count(data.begin(), data.end(), node->name));
        node->distribution[node->name] = count / data.size();
    }
}

// Map user input onto one of 10 discrete states using the text length.
int determineState(const std::string& user_input) {
    const auto len = user_input.length();
    return static_cast<int>(len % 10); // example scheme: 10 states
}

// Toy reward signal: a random integer in [0, 9] weighted by the keyword's
// TF-IDF score.  `action` and `user_input` are placeholders for a richer
// reward model and are currently unused.
double getReward(int action, const std::string& user_input, double tfidf_score) {
    (void)action;
    (void)user_input;
    const int noise = std::rand() % 10;
    return noise * tfidf_score;
}

// Toy transition model: the chosen action alone decides the next state.
int determineNextState(int action) {
    constexpr int kNumStates = 10; // example scheme: 10 states
    return action % kNumStates;
}

// Minimal persistent key -> sentence store.  Contents are loaded from
// `database.txt` on construction and written back on destruction; keys are
// auto-incrementing integers starting at 1.
class Database {
private:
    std::unordered_map<int, std::string> data; // key -> stored sentence
    const std::string filename = "database.txt";
    int nextKey; // next unused key, kept one past the largest loaded key

    // Populate `data` from the backing file; each line is "<key> <value>".
    void loadFromFile() {
        std::ifstream file(filename);
        if (!file) {
            std::cerr << "Error opening file for reading." << std::endl;
            return;
        }
        int key;
        std::string value;
        while (file >> key) {
            std::getline(file, value); // rest of the line is the value
            if (!value.empty() && value[0] == ' ') {
                value.erase(0, 1); // drop the separator space
            }
            data[key] = value;
            nextKey = std::max(nextKey, key + 1); // stay ahead of loaded keys
        }
    }

    // Overwrite the backing file with the current contents of `data`.
    void saveToFile() {
        std::ofstream file(filename, std::ios_base::trunc);
        if (!file) {
            std::cerr << "Error opening file for writing." << std::endl;
            return;
        }
        for (const auto& pair : data) {
            file << pair.first << " " << pair.second << "\n";
        }
    }

public:
    Database() : nextKey(1) { // keys start at 1
        loadFromFile();
    }

    ~Database() {
        saveToFile(); // persist on shutdown
    }

    // Store `value` under a fresh auto-incremented key.
    void add(const std::string& value) {
        data[nextKey] = value;
        std::cout << "Added: " << nextKey << " -> " << value << std::endl;
        nextKey++;
    }

    // Delete the entry with the given key, if present.
    void remove(int key) {
        if (data.erase(key)) {
            std::cout << "Removed: " << key << std::endl;
        } else {
            std::cout << "Key not found: " << key << std::endl;
        }
    }

    // Look up a single entry by key.  (Now const: lookup never mutates.)
    void search(int key) const {
        auto it = data.find(key);
        if (it != data.end()) {
            std::cout << "Found: " << it->first << " -> " << it->second << std::endl;
        } else {
            std::cout << "Key not found: " << key << std::endl;
        }
    }

    // Print every entry whose value contains `token` as a substring.
    // (Now const and single-pass: the original collected matches in a score
    // map, then re-looked each one up through mutating operator[].)
    void searchByText(const std::string& token) const {
        bool found = false;
        for (const auto& pair : data) {
            if (pair.second.find(token) != std::string::npos) {
                std::cout << "Found in key: " << pair.first << " -> " << pair.second << std::endl;
                found = true;
            }
        }
        if (!found) {
            std::cout << "No entries found containing: " << token << std::endl;
        }
    }
};
// Interactive chatbot driver: for each user question, pick the most
// informative keyword with TF-IDF (reinforced by Bayesian-network
// probabilities), print the best matching sentence from myfile.txt,
// archive it in the database, then run a few Q-learning episodes whose
// reward is scaled by the keyword's score.
int main() {
    Database db;
    std::string file_name = "myfile.txt";
    std::string qValuesFile = "q_values.txt"; // persisted Q-table

    std::ifstream inFile(file_name);
    if (!inFile) {
        std::cerr << "Error opening file: " << file_name << std::endl;
        return 1;
    }

    std::unordered_set<std::string> stop_words =  { "a", "about", "above", "after", "again", "against", "all", "am", "an", "and", "any", "are", "aren't", "as", "at", "be", "because",
        "been", "before", "being", "below", "between", "both", "but", "by", "can't", "cannot", "could", "couldn't", "did", "didn't", "do", "does", "doesn't", "doing", "don't",
        "down", "during", "each", "few", "for", "from", "further", "had", "hadn't", "has", "hasn't", "have", "haven't", "having", "he", "he'd", "he'll", "he's", "her", "here",
        "here's", "hers", "herself", "him", "himself", "his", "how", "how's", "i", "i'd", "i'll", "i'm", "i've", "if", "in", "into", "is", "isn't", "it", "it's", "its", "itself",
        "let's", "me", "more", "most", "mustn't", "my", "myself", "no", "nor", "not", "of", "off", "on", "once", "only", "or", "other", "ought", "our", "ours", "ourselves", "out",
        "over", "own", "same", "shan't", "she", "she'd", "she'll", "she's", "should", "shouldn't", "so", "some", "such", "than", "that", "that's", "the", "their", "theirs",
        "them", "themselves", "then", "there", "there's", "these", "they", "they'd", "they'll", "they're", "they've", "this", "those", "through", "to", "too", "under", "until",
        "up", "very", "was", "wasn't", "we", "we'd", "we'll", "we're", "we've", "were", "weren't", "what", "what's", "when", "when's", "where", "where's", "which", "while",
        "who", "who's", "whom", "why", "why's", "with", "won't", "would", "wouldn't", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves" };

    // Initialize the Q-learning agent
    const int numStates = 10;  // number of discrete states
    const int numActions = 5;  // number of actions
    double gamma = 0.9;        // discount factor
    double alpha = 0.1;        // learning rate
    size_t replayMemorySize = 1000; // replay-buffer capacity
    QLearning agent(numStates, numActions, gamma, alpha, replayMemorySize);

    // Restore any previously learned Q-values
    agent.loadQValues(qValuesFile);

    // Parameters for the random word embeddings
    Vector alphaParams = {0.5, 0.5}; // example alpha (mean) parameters
    Vector beta_param = {0.1, 0.1};  // example beta (scale) parameters

    std::string user_input;
    while (true) {
        std::cout << "Please enter your question (or type 'exit' to quit): ";
        std::getline(std::cin, user_input);

        if (user_input == "exit") {
            break; // leave the loop when the user types 'exit'
        }

        std::string filtered_input = remove_stop_words(user_input, stop_words);

        // Gather document statistics, re-reading the file from the start
        int word_count = 0;
        std::vector<double> line_lengths;
        std::string line;

        inFile.clear();
        inFile.seekg(0, std::ios::beg);

        while (std::getline(inFile, line)) {
            std::vector<std::string> words = split(line);
            word_count += words.size();
            line_lengths.push_back(line.length());
        }

        double avg_doc_len = line_lengths.empty() ? 0.0 : std::accumulate(line_lengths.begin(), line_lengths.end(), 0.0) / line_lengths.size();
        // Guard the empty/blank-file case: avg_doc_len == 0.0 previously
        // caused a division by zero here.
        double doc_len_correction = (avg_doc_len > 0.0)
            ? 0.75 * word_count / avg_doc_len + 0.25
            : 0.25;

        // Build the Bayesian network over the question's keywords
        BayesianNetwork BN = structureLearning(split(filtered_input));

        // Estimate node marginals from the same tokens
        parameterLearning(BN, split(filtered_input));

        // Word embeddings (dialogues is currently an empty placeholder)
        std::unordered_map<std::string, std::unordered_map<std::string, int>> dialogues;
        WordEmbeddings wordEmbeddings = getWordEmbeddings(dialogues, alphaParams, beta_param);

        // TF-IDF score per keyword
        std::unordered_map<std::string, double> tfidf_scores;
        for (const auto& token : split(filtered_input)) {
            auto result = compute_tfidf(token, file_name, word_count, avg_doc_len, doc_len_correction);
            tfidf_scores[token] = result.second;
        }

        // Reinforce the TF-IDF scores with conditional probabilities
        for (const auto& word1 : split(filtered_input)) {
            for (const auto& word2 : split(filtered_input)) {
                if (word1 != word2) {
                    std::unordered_map<std::string, std::string> evidence; // no evidence collected yet
                    double prob = BN.getConditionalProbability(word1, evidence);
                    tfidf_scores[word1] *= prob; // reinforce the TF-IDF score
                }
            }
        }

        // Pick the keyword with the best score
        std::string best_token;
        double best_score = 0.0;
        if (!tfidf_scores.empty()) {
            auto best_it = std::max_element(tfidf_scores.begin(), tfidf_scores.end(),
                                             [](const auto& a, const auto& b) { return a.second < b.second; });
            best_token = best_it->first;
            best_score = best_it->second;

            std::cout << "Let's talk about \"" << best_token << "\"..." << std::endl;
            // Compute the best sentence once and reuse it — the original
            // called compute_tfidf twice (two full file scans) for the
            // same output.
            std::string goodSentence = compute_tfidf(best_token, file_name, word_count, avg_doc_len, doc_len_correction).first;
            std::cout << "Best line: " << goodSentence << std::endl;
            db.add(goodSentence);

            // Embedding-based similarity between the question and keyword
            double similarity = computeSimilarity(user_input, best_token, wordEmbeddings);
            std::cout << "Similarity score: " << similarity << std::endl;

            // Dump all TF-IDF scores for debugging
            for (const auto& [token, score] : tfidf_scores) {
                std::cout << "Token: " << token << ", TF-IDF Score: " << score << std::endl;
            }
        } else {
            std::cout << "No relevant information found in the file." << std::endl;
        }

        // Q-learning episodes driven by the keyword score
        double exploreRate = 1.0; // initial exploration rate
        int state = determineState(user_input); // state derived from the input

        for (int episode = 0; episode < 10; ++episode) { // training episodes
            // Choose an action
            int action = agent.chooseAction(state, exploreRate);

            // Reward is weighted by the best TF-IDF score (0 when none)
            double reward = tfidf_scores.empty()
                ? getReward(action, user_input, 0.0)
                : getReward(action, user_input, best_score);

            int nextState = determineNextState(action);

            // Store the transition and learn from a replay batch
            agent.addExperience(state, action, reward, nextState);
            agent.learnFromReplay(32); // batch of 32 experiences

            state = nextState;
            exploreRate *= 0.99; // decay exploration over time
        }

        // Show the learned Q-table
        agent.getQValue();
    }

    // Persist the Q-table for the next run
    agent.saveQValues(qValuesFile);

    return 0; // inFile closes itself via RAII
}
 

Ask a Question

Want to reply to this thread or ask your own question?

You'll need to choose a username for the site, which only takes a couple of moments. After that, you can post your question and our members will help you out.

Ask a Question

Members online

No members online now.

Forum statistics

Threads
474,056
Messages
2,570,440
Members
47,101
Latest member
DoloresHol

Latest Threads

Top