mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-05-15 03:38:04 +00:00
316 lines
11 KiB
Python
316 lines
11 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
import sqlite3
|
||
|
||
from .ontology import all_domains, db_domains
|
||
|
||
|
||
class MultiWozDB(object):
|
||
def __init__(self, db_dir, db_paths):
|
||
self.dbs = {}
|
||
self.sql_dbs = {}
|
||
for domain in all_domains:
|
||
with open(os.path.join(db_dir, db_paths[domain]),
|
||
'r',
|
||
encoding='utf-8') as f:
|
||
self.dbs[domain] = json.loads(f.read().lower())
|
||
|
||
def oneHotVector(self, domain, num):
|
||
"""Return number of available entities for particular domain."""
|
||
vector = [0, 0, 0, 0]
|
||
if num == '':
|
||
return vector
|
||
if domain != 'train':
|
||
if num == 0:
|
||
vector = [1, 0, 0, 0]
|
||
elif num == 1:
|
||
vector = [0, 1, 0, 0]
|
||
elif num <= 3:
|
||
vector = [0, 0, 1, 0]
|
||
else:
|
||
vector = [0, 0, 0, 1]
|
||
else:
|
||
if num == 0:
|
||
vector = [1, 0, 0, 0]
|
||
elif num <= 5:
|
||
vector = [0, 1, 0, 0]
|
||
elif num <= 10:
|
||
vector = [0, 0, 1, 0]
|
||
else:
|
||
vector = [0, 0, 0, 1]
|
||
return vector
|
||
|
||
def addBookingPointer(self, turn_da):
|
||
"""Add information about availability of the booking option."""
|
||
# Booking pointer
|
||
# Do not consider booking two things in a single turn.
|
||
vector = [0, 0]
|
||
if turn_da.get('booking-nobook'):
|
||
vector = [1, 0]
|
||
if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
|
||
vector = [0, 1]
|
||
return vector
|
||
|
||
def addDBPointer(self, domain, match_num, return_num=False):
|
||
"""Create database pointer for all related domains."""
|
||
# if turn_domains is None:
|
||
# turn_domains = db_domains
|
||
if domain in db_domains:
|
||
vector = self.oneHotVector(domain, match_num)
|
||
else:
|
||
vector = [0, 0, 0, 0]
|
||
return vector
|
||
|
||
def addDBIndicator(self, domain, match_num, return_num=False):
|
||
"""Create database indicator for all related domains."""
|
||
# if turn_domains is None:
|
||
# turn_domains = db_domains
|
||
if domain in db_domains:
|
||
vector = self.oneHotVector(domain, match_num)
|
||
else:
|
||
vector = [0, 0, 0, 0]
|
||
|
||
# '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
|
||
if vector == [0, 0, 0, 0]:
|
||
indicator = '[db_nores]'
|
||
else:
|
||
indicator = '[db_%s]' % vector.index(1)
|
||
return indicator
|
||
|
||
def get_match_num(self, constraints, return_entry=False):
|
||
"""Create database pointer for all related domains."""
|
||
match = {'general': ''}
|
||
entry = {}
|
||
# if turn_domains is None:
|
||
# turn_domains = db_domains
|
||
for domain in all_domains:
|
||
match[domain] = ''
|
||
if domain in db_domains and constraints.get(domain):
|
||
matched_ents = self.queryJsons(domain, constraints[domain])
|
||
match[domain] = len(matched_ents)
|
||
if return_entry:
|
||
entry[domain] = matched_ents
|
||
if return_entry:
|
||
return entry
|
||
return match
|
||
|
||
def pointerBack(self, vector, domain):
|
||
# multi domain implementation
|
||
# domnum = cfg.domain_num
|
||
if domain.endswith(']'):
|
||
domain = domain[1:-1]
|
||
if domain != 'train':
|
||
nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
|
||
else:
|
||
nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
|
||
if vector[:4] == [0, 0, 0, 0]:
|
||
report = ''
|
||
else:
|
||
num = vector.index(1)
|
||
report = domain + ': ' + nummap[num] + '; '
|
||
|
||
if vector[-2] == 0 and vector[-1] == 1:
|
||
report += 'booking: ok'
|
||
if vector[-2] == 1 and vector[-1] == 0:
|
||
report += 'booking: unable'
|
||
|
||
return report
|
||
|
||
def queryJsons(self,
|
||
domain,
|
||
constraints,
|
||
exactly_match=True,
|
||
return_name=False):
|
||
"""Returns the list of entities for a given domain
|
||
based on the annotation of the belief state
|
||
constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
|
||
"""
|
||
# query the db
|
||
if domain == 'taxi':
|
||
return [{
|
||
'taxi_colors':
|
||
random.choice(self.dbs[domain]['taxi_colors']),
|
||
'taxi_types':
|
||
random.choice(self.dbs[domain]['taxi_types']),
|
||
'taxi_phone': [random.randint(1, 9) for _ in range(10)]
|
||
}]
|
||
if domain == 'police':
|
||
return self.dbs['police']
|
||
if domain == 'hospital':
|
||
if constraints.get('department'):
|
||
for entry in self.dbs['hospital']:
|
||
if entry.get('department') == constraints.get(
|
||
'department'):
|
||
return [entry]
|
||
else:
|
||
return []
|
||
|
||
valid_cons = False
|
||
for v in constraints.values():
|
||
if v not in ['not mentioned', '']:
|
||
valid_cons = True
|
||
if not valid_cons:
|
||
return []
|
||
|
||
match_result = []
|
||
|
||
if 'name' in constraints:
|
||
for db_ent in self.dbs[domain]:
|
||
if 'name' in db_ent:
|
||
cons = constraints['name']
|
||
dbn = db_ent['name']
|
||
if cons == dbn:
|
||
db_ent = db_ent if not return_name else db_ent['name']
|
||
match_result.append(db_ent)
|
||
return match_result
|
||
|
||
for db_ent in self.dbs[domain]:
|
||
match = True
|
||
for s, v in constraints.items():
|
||
if s == 'name':
|
||
continue
|
||
if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
|
||
(domain == 'restaurant' and s in ['day', 'time']):
|
||
# These inform slots belong to "book info",which do not exist in DB
|
||
# "book" is according to the user goal,not DB
|
||
continue
|
||
|
||
skip_case = {
|
||
"don't care": 1,
|
||
"do n't care": 1,
|
||
'dont care': 1,
|
||
'not mentioned': 1,
|
||
'dontcare': 1,
|
||
'': 1
|
||
}
|
||
if skip_case.get(v):
|
||
continue
|
||
|
||
if s not in db_ent:
|
||
# logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
|
||
match = False
|
||
break
|
||
|
||
# v = 'guesthouse' if v == 'guest house' else v
|
||
# v = 'swimmingpool' if v == 'swimming pool' else v
|
||
v = 'yes' if v == 'free' else v
|
||
|
||
if s in ['arrive', 'leave']:
|
||
try:
|
||
h, m = v.split(
|
||
':'
|
||
) # raise error if time value is not xx:xx format
|
||
v = int(h) * 60 + int(m)
|
||
except Exception:
|
||
match = False
|
||
break
|
||
time = int(db_ent[s].split(':')[0]) * 60 + int(
|
||
db_ent[s].split(':')[1])
|
||
if s == 'arrive' and v > time:
|
||
match = False
|
||
if s == 'leave' and v < time:
|
||
match = False
|
||
else:
|
||
if exactly_match and v != db_ent[s]:
|
||
match = False
|
||
break
|
||
elif v not in db_ent[s]:
|
||
match = False
|
||
break
|
||
|
||
if match:
|
||
match_result.append(db_ent)
|
||
|
||
if not return_name:
|
||
return match_result
|
||
else:
|
||
if domain == 'train':
|
||
match_result = [e['id'] for e in match_result]
|
||
else:
|
||
match_result = [e['name'] for e in match_result]
|
||
return match_result
|
||
|
||
def querySQL(self, domain, constraints):
|
||
if not self.sql_dbs:
|
||
for dom in db_domains:
|
||
db = 'db/{}-dbase.db'.format(dom)
|
||
conn = sqlite3.connect(db)
|
||
c = conn.cursor()
|
||
self.sql_dbs[dom] = c
|
||
|
||
sql_query = 'select * from {}'.format(domain)
|
||
|
||
flag = True
|
||
for key, val in constraints.items():
|
||
if val == '' \
|
||
or val == 'dontcare' \
|
||
or val == 'not mentioned' \
|
||
or val == "don't care" \
|
||
or val == 'dont care' \
|
||
or val == "do n't care":
|
||
pass
|
||
else:
|
||
if flag:
|
||
sql_query += ' where '
|
||
val2 = val.replace("'", "''")
|
||
# val2 = normalize(val2)
|
||
if key == 'leaveAt':
|
||
sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
|
||
elif key == 'arriveBy':
|
||
sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
|
||
else:
|
||
sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
|
||
flag = False
|
||
else:
|
||
val2 = val.replace("'", "''")
|
||
# val2 = normalize(val2)
|
||
if key == 'leaveAt':
|
||
sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
|
||
elif key == 'arriveBy':
|
||
sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
|
||
else:
|
||
sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"
|
||
|
||
try: # "select * from attraction where name = 'queens college'"
|
||
print(sql_query)
|
||
return self.sql_dbs[domain].execute(sql_query).fetchall()
|
||
except Exception:
|
||
return [] # TODO test it
|
||
|
||
|
||
if __name__ == '__main__':
|
||
dbPATHs = {
|
||
'attraction': 'db/attraction_db_processed.json',
|
||
'hospital': 'db/hospital_db_processed.json',
|
||
'hotel': 'db/hotel_db_processed.json',
|
||
'police': 'db/police_db_processed.json',
|
||
'restaurant': 'db/restaurant_db_processed.json',
|
||
'taxi': 'db/taxi_db_processed.json',
|
||
'train': 'db/train_db_processed.json',
|
||
}
|
||
db = MultiWozDB(dbPATHs)
|
||
while True:
|
||
constraints = {}
|
||
inp = input(
|
||
'input belief state in fomat: domain-slot1=value1;slot2=value2...\n'
|
||
)
|
||
domain, cons = inp.split('-')
|
||
for sv in cons.split(';'):
|
||
s, v = sv.split('=')
|
||
constraints[s] = v
|
||
# res = db.querySQL(domain, constraints)
|
||
res = db.queryJsons(domain, constraints, return_name=True)
|
||
report = []
|
||
reidx = {
|
||
'hotel': 8,
|
||
'restaurant': 6,
|
||
'attraction': 5,
|
||
'train': 1,
|
||
}
|
||
print(constraints)
|
||
print(res)
|
||
print('count:', len(res), '\nnames:', report)
|