You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

sentence_transformers_op.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from sentence_transformers import SentenceTransformer
  2. from sentence_transformers import util
  3. from dora import DoraStatus
  4. import os
  5. import sys
  6. import torch
  7. import pyarrow as pa
  8. SHOULD_BE_INCLUDED = [
  9. "webcam.py",
  10. "object_detection.py",
  11. "plot.py",
  12. ]
  13. ## Get all python files path in given directory
  14. def get_all_functions(path):
  15. raw = []
  16. paths = []
  17. for root, dirs, files in os.walk(path):
  18. for file in files:
  19. if file.endswith(".py"):
  20. if file not in SHOULD_BE_INCLUDED:
  21. continue
  22. path = os.path.join(root, file)
  23. with open(path, "r", encoding="utf8") as f:
  24. ## add file folder to system path
  25. sys.path.append(root)
  26. ## import module from path
  27. raw.append(f.read())
  28. paths.append(path)
  29. return raw, paths
  30. def search(query_embedding, corpus_embeddings, paths, raw, k=5, file_extension=None):
  31. cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
  32. top_results = torch.topk(cos_scores, k=min(k, len(cos_scores)), sorted=True)
  33. out = []
  34. for score, idx in zip(top_results[0], top_results[1]):
  35. out.extend([raw[idx], paths[idx], score])
  36. return out
  37. class Operator:
  38. """ """
  39. def __init__(self):
  40. ## TODO: Add a initialisation step
  41. self.model = SentenceTransformer("BAAI/bge-large-en-v1.5")
  42. self.encoding = []
  43. # file directory
  44. path = os.path.dirname(os.path.abspath(__file__))
  45. self.raw, self.path = get_all_functions(path)
  46. # Encode all files
  47. self.encoding = self.model.encode(self.raw)
  48. def on_event(
  49. self,
  50. dora_event,
  51. send_output,
  52. ) -> DoraStatus:
  53. if dora_event["type"] == "INPUT":
  54. if dora_event["id"] == "query":
  55. values = dora_event["value"].to_pylist()
  56. query_embeddings = self.model.encode(values)
  57. output = search(
  58. query_embeddings,
  59. self.encoding,
  60. self.path,
  61. self.raw,
  62. )
  63. [raw, path, score] = output[0:3]
  64. send_output(
  65. "raw_file",
  66. pa.array([{"raw": raw, "path": path, "user_message": values[0]}]),
  67. dora_event["metadata"],
  68. )
  69. else:
  70. input = dora_event["value"][0].as_py()
  71. index = self.path.index(input["path"])
  72. self.raw[index] = input["raw"]
  73. self.encoding[index] = self.model.encode([input["raw"]])[0]
  74. return DoraStatus.CONTINUE
# When this file is executed directly as a script, instantiate the operator
# once (this runs Operator.__init__).
if __name__ == "__main__":
    operator = Operator()