You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sentence_transformers_op.py 2.9 kB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """Dora operator that embeds a fixed set of Python source files and returns the one most similar to a text query."""
  2. import os
  3. import sys
  4. import pyarrow as pa
  5. import torch
  6. from dora import DoraStatus
  7. from sentence_transformers import SentenceTransformer, util
  8. SHOULD_BE_INCLUDED = [
  9. "webcam.py",
  10. "object_detection.py",
  11. "plot.py",
  12. ]
  13. ## Get all python files path in given directory
  14. def get_all_functions(path):
  15. """TODO: Add docstring."""
  16. raw = []
  17. paths = []
  18. for root, dirs, files in os.walk(path):
  19. for file in files:
  20. if file.endswith(".py"):
  21. if file not in SHOULD_BE_INCLUDED:
  22. continue
  23. path = os.path.join(root, file)
  24. with open(path, encoding="utf8") as f:
  25. ## add file folder to system path
  26. sys.path.append(root)
  27. ## import module from path
  28. raw.append(f.read())
  29. paths.append(path)
  30. return raw, paths
  31. def search(query_embedding, corpus_embeddings, paths, raw, k=5, file_extension=None):
  32. """TODO: Add docstring."""
  33. cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
  34. top_results = torch.topk(cos_scores, k=min(k, len(cos_scores)), sorted=True)
  35. out = []
  36. for score, idx in zip(top_results[0], top_results[1]):
  37. out.extend([raw[idx], paths[idx], score])
  38. return out
  39. class Operator:
  40. """ """
  41. def __init__(self):
  42. ## TODO: Add a initialisation step
  43. """TODO: Add docstring."""
  44. self.model = SentenceTransformer("BAAI/bge-large-en-v1.5")
  45. self.encoding = []
  46. # file directory
  47. path = os.path.dirname(os.path.abspath(__file__))
  48. self.raw, self.path = get_all_functions(path)
  49. # Encode all files
  50. self.encoding = self.model.encode(self.raw)
  51. def on_event(
  52. self,
  53. dora_event,
  54. send_output,
  55. ) -> DoraStatus:
  56. """TODO: Add docstring."""
  57. if dora_event["type"] == "INPUT":
  58. if dora_event["id"] == "query":
  59. values = dora_event["value"].to_pylist()
  60. query_embeddings = self.model.encode(values)
  61. output = search(
  62. query_embeddings,
  63. self.encoding,
  64. self.path,
  65. self.raw,
  66. )
  67. [raw, path, score] = output[0:3]
  68. send_output(
  69. "raw_file",
  70. pa.array([{"raw": raw, "path": path, "user_message": values[0]}]),
  71. dora_event["metadata"],
  72. )
  73. else:
  74. input = dora_event["value"][0].as_py()
  75. index = self.path.index(input["path"])
  76. self.raw[index] = input["raw"]
  77. self.encoding[index] = self.model.encode([input["raw"]])[0]
  78. return DoraStatus.CONTINUE
# When run as a script, build the operator once; construction loads the
# embedding model and encodes the allow-listed files.
if __name__ == "__main__":
    operator = Operator()